
[pytorch] Preventing weight updates for some layers of a model | model freezing

by 뭅즤 2021. 2. 17.

In PyTorch, you stop part of a model from being trained by setting the requires_grad attribute of that part's parameters to False. No gradient is then computed for those parameters, so the optimizer does not update their weights.
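Here is a minimal sketch of this behavior on a small, hypothetical two-layer model that is not from the original post. After backward(), the frozen layer's .grad stays None. An SGD step then leaves the frozen weights unchanged while the trainable layer is updated. You can still pass the frozen tensors to the optimizer, because the built-in torch.optim optimizers skip parameters whose .grad is None.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: we freeze the first Linear layer only
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for param in model[0].parameters():
    param.requires_grad = False

# Snapshot the weights before the update
frozen_before = model[0].weight.detach().clone()
trainable_before = model[2].weight.detach().clone()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 4)
target = torch.randn(16, 2)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()

# No gradient is computed for the frozen parameters
print(model[0].weight.grad)          # None
print(model[2].weight.grad is None)  # False

optimizer.step()

# The frozen layer is unchanged; the trainable layer was updated
print(torch.equal(model[0].weight, frozen_before))     # True
print(torch.equal(model[2].weight, trainable_before))  # False
```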

The usual approach is to iterate over model.named_parameters(), which yields (name, param) pairs. You then change requires_grad only for the layers whose names you want to freeze.
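Consider a small hypothetical model whose submodules are named backbone and head. These are illustrative names, not a library API. Its named_parameters() yields dotted names such as backbone.0.weight, which you can match on:

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical model with named submodules so that parameter names are readable
model = nn.Sequential(OrderedDict([
    ("backbone", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))),
    ("head", nn.Linear(16, 10)),
]))

# Freeze every parameter whose name starts with "backbone."
for name, param in model.named_parameters():
    if name.startswith("backbone."):
        param.requires_grad = False

for name, param in model.named_parameters():
    print(f"{name:20s} requires_grad={param.requires_grad}")
```

This prints requires_grad=False for the four backbone.* tensors and requires_grad=True for head.weight and head.bias. Matching on a prefix that ends in a dot is slightly safer than a bare substring test. For example, 'layer1' in name would also match a hypothetical layer10.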


๋ชจ๋ธ ํ”„๋ฆฌ์ง• ์˜ˆ์‹œ

- ResNet example

import torch
import torch.nn as nn
import torchvision.models as models

# Load a pre-trained ResNet-50
# (the weights= API needs torchvision >= 0.13; older versions use models.resnet50(pretrained=True))
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Set requires_grad to False for the parameters of layer1, layer2 and layer3
for name, param in resnet50.named_parameters():
    # Parameter names look like "layer1.0.conv1.weight", so match on the prefix
    if name.startswith(('layer1.', 'layer2.', 'layer3.')):
        param.requires_grad = False
  • This example sets requires_grad to False only for the 'layer1', 'layer2' and 'layer3' parts of resnet50. conv1, bn1, layer4 and fc remain trainable.
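There is one caveat to check with the ResNet example: layer1 to layer3 contain BatchNorm layers. Setting requires_grad = False stops gradient updates to their weight and bias. However, in train() mode BatchNorm still updates its running_mean and running_var buffers on every forward pass. If you want those statistics frozen too, one common workaround is to switch the frozen BatchNorm layers back to eval mode after calling model.train(). The sketch below shows this on a small hypothetical model rather than on ResNet-50 itself. The helper set_bn_eval and the constant BN_TYPES are illustrative names, not PyTorch APIs.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a backbone: two conv + BatchNorm blocks
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),  # block to freeze (indices 0-2)
    nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),  # block to train (indices 3-5)
)
# Slicing an nn.Sequential returns a new Sequential that shares the same submodules
frozen_part = model[:3]
for param in frozen_part.parameters():
    param.requires_grad = False

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def set_bn_eval(module: nn.Module) -> None:
    """Switch every BatchNorm layer inside `module` to eval mode so its running stats stop updating."""
    for m in module.modules():
        if isinstance(m, BN_TYPES):
            m.eval()


model.train()             # sets train mode on every submodule, including the frozen BatchNorm
set_bn_eval(frozen_part)  # so re-apply eval mode to the frozen BatchNorm layers afterwards

frozen_stats = model[1].running_mean.clone()
trainable_stats = model[4].running_mean.clone()
model(torch.randn(4, 3, 16, 16))

print(torch.equal(model[1].running_mean, frozen_stats))     # True: frozen BN stats unchanged
print(torch.equal(model[4].running_mean, trainable_stats))  # False: trainable BN stats updated
```

model.train() resets every submodule to train mode. You therefore need to call the helper again after each model.train() call, for example at the start of every epoch.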


- DenseNet example

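import torch.nn as nn
import torchvision.models as models

# Load a pre-trained DenseNet-121 (torchvision >= 0.13; older versions use pretrained=True)
model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
print(model.state_dict().keys())  # inspect the model's keys (layer names)

for param in model.parameters():  # freeze every parameter in the model
    param.requires_grad = False

for param in model.features.denseblock4.denselayer16.parameters():  # unfreeze only the part you want to train
    param.requires_grad = True

num_of_class = 10  # number of target classes (example value; set this for your dataset)
# Replace the classifier (in_features is 1024 for densenet121); a newly created layer is trainable by default
model.classifier = nn.Linear(model.classifier.in_features, num_of_class)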
  • ๋ชจ๋ธ ์›จ์ดํŠธ ์ „์ฒด required_grad ๋ฅผ False๋กœ ๋ณ€๊ฒฝํ•˜๊ณ  ํ•™์Šตํ•˜๊ณ  ์‹ถ์€ ๋ ˆ์ด์–ด๋งŒ True๋กœ ๋ณ€๊ฒฝํ•ด์„œ ์‚ฌ์šฉ
