๐ป Programming/AI & ML
[pytorch] ๋ชจ๋ธ์ ์ผ๋ถ ๋ ์ด์ด ์จ์ดํธ ์ ๋ฐ์ดํธ ๋ง๊ธฐ | model freezing (๋ชจ๋ธ ํ๋ฆฌ์ง)
๋ญ
์ฆค
2021. 2. 17. 00:39
๋ฐ์ํ
PyTorch์์ ๋ชจ๋ธ ์ผ๋ถ๋ถ์ ํ์ต์๋ง๊ธฐ ์ํด์๋ ํด๋น ๋ถ๋ถ์ ๋งค๊ฐ๋ณ์์ ๋ํด requires_grad ์์ฑ์ False๋ก ์ค์ ํ๋ฉด ๋๋ค. ์ด๋ฅผ ํตํด ๊ทธ ๋งค๊ฐ๋ณ์์ ๋ํ gradient๊ฐ ๊ณ์ฐ๋์ง ์์ ์จ์ดํธ๊ฐ ์
๋ฐ์ดํธ๋์ง ์๋๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก model.named_parameters() ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ name๊ณผ param์ ์ถ์ถํ๊ณ ์ํ๋ name์ ๋ ์ด์ด๋ง requires_grad๋ฅผ ๋ณ๊ฒฝํ ์ ์๋ค.
๋ชจ๋ธ ํ๋ฆฌ์ง ์์
- Resnet ์์
import torch
import torch.nn as nn
import torchvision.models as models
# ResNet-50 ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ (pre-trained)
resnet50 = models.resnet50(pretrained=True)
# layer1, layer2, layer3์ ๋งค๊ฐ๋ณ์์ ๋ํด requires_grad๋ฅผ False๋ก ์ค์
for name, param in resnet50.named_parameters():
if 'layer1' in name or 'layer2' in name or 'layer3' in name:
param.requires_grad = False
- resnet50์ 'layer1', 'layer2', 'layer3' ๋ถ๋ถ๋ง required_grad ๋ฅผ False๋ก ๋ณ๊ฒฝํ๋ ์์
- Densenet ์์
model = models.densenet121(pretrained=True)
print(model.state_dict().keys()) # model ์ key๊ฐ์ ํ์ธ
for param in model.parameters(): # model์ ๋ชจ๋ parameter ๋ฅผ freeze
param.requires_grad = False
for param in model.features.denseblock4.denselayer16.parameters(): # ๋ด๊ฐ training ์ํค๊ธฐ ์ํ๋ ๋ถ๋ถ๋ง unfreeze
param.requires_grad = True
model.classifier = nn.Linear(1024,num_of_class) # classifier ๋ถ๋ถ ์ฌ์ค์
- ๋ชจ๋ธ ์จ์ดํธ ์ ์ฒด required_grad ๋ฅผ False๋ก ๋ณ๊ฒฝํ๊ณ ํ์ตํ๊ณ ์ถ์ ๋ ์ด์ด๋ง True๋ก ๋ณ๊ฒฝํด์ ์ฌ์ฉ
๋ฐ์ํ