๋ฐ์ํ
PyTorch์์ ๋ชจ๋ธ ์ผ๋ถ๋ถ์ ํ์ต์๋ง๊ธฐ ์ํด์๋ ํด๋น ๋ถ๋ถ์ ๋งค๊ฐ๋ณ์์ ๋ํด requires_grad ์์ฑ์ False๋ก ์ค์ ํ๋ฉด ๋๋ค. ์ด๋ฅผ ํตํด ๊ทธ ๋งค๊ฐ๋ณ์์ ๋ํ gradient๊ฐ ๊ณ์ฐ๋์ง ์์ ์จ์ดํธ๊ฐ ์
๋ฐ์ดํธ๋์ง ์๋๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก model.named_parameters() ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ name๊ณผ param์ ์ถ์ถํ๊ณ ์ํ๋ name์ ๋ ์ด์ด๋ง requires_grad๋ฅผ ๋ณ๊ฒฝํ ์ ์๋ค.
๋ชจ๋ธ ํ๋ฆฌ์ง ์์
- Resnet ์์
import torch
import torch.nn as nn
import torchvision.models as models
# ResNet-50 ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ (pre-trained)
resnet50 = models.resnet50(pretrained=True)
# layer1, layer2, layer3์ ๋งค๊ฐ๋ณ์์ ๋ํด requires_grad๋ฅผ False๋ก ์ค์
for name, param in resnet50.named_parameters():
if 'layer1' in name or 'layer2' in name or 'layer3' in name:
param.requires_grad = False
- resnet50์ 'layer1', 'layer2', 'layer3' ๋ถ๋ถ๋ง required_grad ๋ฅผ False๋ก ๋ณ๊ฒฝํ๋ ์์
- Densenet ์์
model = models.densenet121(pretrained=True)
print(model.state_dict().keys()) # model ์ key๊ฐ์ ํ์ธ
for param in model.parameters(): # model์ ๋ชจ๋ parameter ๋ฅผ freeze
param.requires_grad = False
for param in model.features.denseblock4.denselayer16.parameters(): # ๋ด๊ฐ training ์ํค๊ธฐ ์ํ๋ ๋ถ๋ถ๋ง unfreeze
param.requires_grad = True
model.classifier = nn.Linear(1024,num_of_class) # classifier ๋ถ๋ถ ์ฌ์ค์
- ๋ชจ๋ธ ์จ์ดํธ ์ ์ฒด required_grad ๋ฅผ False๋ก ๋ณ๊ฒฝํ๊ณ ํ์ตํ๊ณ ์ถ์ ๋ ์ด์ด๋ง True๋ก ๋ณ๊ฒฝํด์ ์ฌ์ฉ
๋ฐ์ํ
'๐ป Programming > AI & ML' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[pytorch] pytorch ๋ชจ๋ธ ๋ก๋ ์ค Missing key(s) in state_dict ์๋ฌ (0) | 2022.12.15 |
---|---|
[pytorch] COCO Data Format ์ ์ฉ Custom Dataset ์์ฑ (1) | 2022.06.04 |
[pytorch] model ์ ์ ๊ทผํ๊ธฐ, ํน์ layer ๋ณ๊ฒฝํ๊ธฐ (0) | 2022.01.05 |
[pytorch] Custom dataset, dataloader ๋ง๋ค๊ธฐ (0) | 2022.01.02 |
[pytorch] DataParallel ๋ก ํ์ตํ ๋ชจ๋ธ load (0) | 2021.02.17 |