[pytorch] Preventing weight updates for some layers of a model | model freezing

2021. 2. 17. 00:39 · 💻 Programming/AI & ML
๋ฐ˜์‘ํ˜•

To stop part of a model from being trained in PyTorch, set the requires_grad attribute of that part's parameters to False. No gradient is then computed for those parameters, so the optimizer leaves their weights unchanged.
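The effect described above can be checked on a small model. The sketch below is only an illustration: the two-layer network `toy`, the random input, and the SGD settings are all made up for this check and are not part of the original post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer model, used only for illustration.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first Linear layer.
for param in toy[0].parameters():
    param.requires_grad = False

# Keep copies of the weights so they can be compared after one update.
frozen_before = toy[0].weight.clone()
head_before = toy[2].weight.clone()

optimizer = torch.optim.SGD(toy.parameters(), lr=0.1)
loss = toy(torch.randn(16, 4)).pow(2).mean()
loss.backward()
optimizer.step()

frozen_grad_is_none = toy[0].weight.grad is None            # no gradient was computed
frozen_unchanged = torch.equal(toy[0].weight, frozen_before)  # weights did not move
head_changed = not torch.equal(toy[2].weight, head_before)    # unfrozen layer was updated
print(frozen_grad_is_none, frozen_unchanged, head_changed)
```

The frozen layer's `.grad` stays `None`, and the optimizer skips parameters without a gradient, so only the unfrozen layer should move.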

The basic approach is to iterate over model.named_parameters(), which yields (name, param) pairs, and change requires_grad only for the layers whose names match the ones you want.
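Before matching on names, it helps to print them and see which prefixes exist. The sketch below uses a hypothetical `TinyNet` with submodules named `backbone` and `head` (names chosen for this example only) and matches with `startswith`, which avoids accidental hits that a plain substring test could produce.

```python
import torch.nn as nn

# Hypothetical model with named submodules, for illustration only.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.backbone(x).relu())

tiny = TinyNet()

# Parameter names are "<submodule>.<parameter>", e.g. "backbone.weight".
for name, param in tiny.named_parameters():
    if name.startswith("backbone."):
        param.requires_grad = False

# Collect the result so it is easy to inspect which parameters remain trainable.
status = {name: param.requires_grad for name, param in tiny.named_parameters()}
for name, trainable in status.items():
    print(f"{name:16s} trainable={trainable}")
```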

 

๋ชจ๋ธ ํ”„๋ฆฌ์ง• ์˜ˆ์‹œ

- Resnet ์˜ˆ์‹œ

import torch
import torch.nn as nn
import torchvision.models as models

# Load a pre-trained ResNet-50.
# Note: torchvision 0.13+ deprecates pretrained=True in favor of
# weights=models.ResNet50_Weights.IMAGENET1K_V1.
resnet50 = models.resnet50(pretrained=True)

# Set requires_grad to False for the parameters of layer1, layer2, and layer3.
for name, param in resnet50.named_parameters():
    if 'layer1' in name or 'layer2' in name or 'layer3' in name:
        param.requires_grad = False
  • An example that sets requires_grad to False only for the 'layer1', 'layer2', and 'layer3' parts of resnet50
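As an alternative to the name-matching loop, `nn.Module.requires_grad_(False)` can be called on a submodule to flip the flag for all of its parameters at once. The sketch below is a swapped-in variant, not the method used in the post; `MiniResNetLike` is a hypothetical stand-in that only mimics ResNet's attribute names so the example runs without torchvision or downloaded weights.

```python
import torch.nn as nn

# Hypothetical stand-in that mimics ResNet's attribute names (layer1..layer4, fc).
# It is not the torchvision model and loads no pre-trained weights.
class MiniResNetLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 8)
        self.layer2 = nn.Linear(8, 8)
        self.layer3 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = block(x).relu()
        return self.fc(x)

net = MiniResNetLike()

# requires_grad_(False) sets requires_grad on every parameter of the submodule.
for block in (net.layer1, net.layer2, net.layer3):
    block.requires_grad_(False)

frozen = sorted(n for n, p in net.named_parameters() if not p.requires_grad)
trainable = sorted(n for n, p in net.named_parameters() if p.requires_grad)
print("frozen:   ", frozen)
print("trainable:", trainable)
```

On the real model the same call would presumably be `resnet50.layer1.requires_grad_(False)`, and so on for the other layers.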

 

 

 

- DenseNet example

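import torch.nn as nn
import torchvision.models as models

# Note: torchvision 0.13+ deprecates pretrained=True in favor of
# weights=models.DenseNet121_Weights.IMAGENET1K_V1.
model = models.densenet121(pretrained=True)
print(model.state_dict().keys())  # inspect the model's keys

for param in model.parameters():  # freeze every parameter in the model
    param.requires_grad = False

for param in model.features.denseblock4.denselayer16.parameters():  # unfreeze only the part to be trained
    param.requires_grad = True

num_of_class = 10  # example value (not from the original post); set to your number of classes
model.classifier = nn.Linear(1024, num_of_class)  # replace the classifier; the new layer is trainable by default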
  • ๋ชจ๋ธ ์›จ์ดํŠธ ์ „์ฒด required_grad ๋ฅผ False๋กœ ๋ณ€๊ฒฝํ•˜๊ณ  ํ•™์Šตํ•˜๊ณ  ์‹ถ์€ ๋ ˆ์ด์–ด๋งŒ True๋กœ ๋ณ€๊ฒฝํ•ด์„œ ์‚ฌ์šฉ

 

๋ฐ˜์‘ํ˜•

'๐Ÿ’ป Programming > AI & ML' ์นดํ…Œ๊ณ ๋ฆฌ์˜ ๋‹ค๋ฅธ ๊ธ€

[pytorch] pytorch ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ Missing key(s) in state_dict ์—๋Ÿฌ  (0) 2022.12.15
[pytorch] COCO Data Format ์ „์šฉ Custom Dataset ์ƒ์„ฑ  (1) 2022.06.04
[pytorch] model ์— ์ ‘๊ทผํ•˜๊ธฐ, ํŠน์ • layer ๋ณ€๊ฒฝํ•˜๊ธฐ  (0) 2022.01.05
[pytorch] Custom dataset, dataloader ๋งŒ๋“ค๊ธฐ  (0) 2022.01.02
[pytorch] DataParallel ๋กœ ํ•™์Šตํ•œ ๋ชจ๋ธ load  (0) 2021.02.17
'๐Ÿ’ป Programming/AI & ML' ์นดํ…Œ๊ณ ๋ฆฌ์˜ ๋‹ค๋ฅธ ๊ธ€
  • [pytorch] COCO Data Format ์ „์šฉ Custom Dataset ์ƒ์„ฑ
  • [pytorch] model ์— ์ ‘๊ทผํ•˜๊ธฐ, ํŠน์ • layer ๋ณ€๊ฒฝํ•˜๊ธฐ
  • [pytorch] Custom dataset, dataloader ๋งŒ๋“ค๊ธฐ
  • [pytorch] DataParallel ๋กœ ํ•™์Šตํ•œ ๋ชจ๋ธ load
๋ญ…์ฆค
๋ญ…์ฆค
AI ๊ธฐ์ˆ  ๋ธ”๋กœ๊ทธ
    ๋ฐ˜์‘ํ˜•
  • ๋ญ…์ฆค
    moovzi’s Doodle
    ๋ญ…์ฆค
  • ์ „์ฒด
    ์˜ค๋Š˜
    ์–ด์ œ
  • ๊ณต์ง€์‚ฌํ•ญ

    • โœจ About Me
    • ๋ถ„๋ฅ˜ ์ „์ฒด๋ณด๊ธฐ (213)
      • ๐Ÿ“– Fundamentals (34)
        • Computer Vision (9)
        • 3D vision & Graphics (6)
        • AI & ML (16)
        • NLP (2)
        • etc. (1)
      • ๐Ÿ› Research (75)
        • Deep Learning (7)
        • Perception (19)
        • OCR (7)
        • Multi-modal (5)
        • Image•Video Generation (18)
        • 3D Vision (4)
        • Material • Texture Recognit.. (8)
        • Large-scale Model (7)
        • etc. (0)
      • ๐Ÿ› ๏ธ Engineering (8)
        • Distributed Training & Infe.. (5)
        • AI & ML ์ธ์‚ฌ์ดํŠธ (3)
      • ๐Ÿ’ป Programming (92)
        • Python (18)
        • Computer Vision (12)
        • LLM (4)
        • AI & ML (18)
        • Database (3)
        • Distributed Computing (6)
        • Apache Airflow (6)
        • Docker & Kubernetes (14)
        • ์ฝ”๋”ฉ ํ…Œ์ŠคํŠธ (4)
        • C++ (1)
        • etc. (6)
      • ๐Ÿ’ฌ ETC (4)
        • ์ฑ… ๋ฆฌ๋ทฐ (4)
  • ๋งํฌ

    • ๋ฆฌํ‹€๋ฆฌ ํ”„๋กœํ•„ (๋ฉ˜ํ† ๋ง, ๋ฉด์ ‘์ฑ…,...)
    • ใ€Ž๋‚˜๋Š” AI ์—”์ง€๋‹ˆ์–ด์ž…๋‹ˆ๋‹คใ€
    • Instagram
    • Brunch
    • Github
  • ์ธ๊ธฐ ๊ธ€

  • ์ตœ๊ทผ ๋Œ“๊ธ€

  • ์ตœ๊ทผ ๊ธ€

  • hELLOยท Designed By์ •์ƒ์šฐ.v4.10.3
๋ญ…์ฆค
[pytorch] ๋ชจ๋ธ์˜ ์ผ๋ถ€ ๋ ˆ์ด์–ด ์›จ์ดํŠธ ์—…๋ฐ์ดํŠธ ๋ง‰๊ธฐ | model freezing (๋ชจ๋ธ ํ”„๋ฆฌ์ง•)
์ƒ๋‹จ์œผ๋กœ

ํ‹ฐ์Šคํ† ๋ฆฌํˆด๋ฐ”