๋ฐ์ํ
PyTorch์์๋ ๋ชจ๋ธ์ ์ผ๋ถ๋ถ๋ง ์ ์ฅํ๊ฑฐ๋ ๋ถ๋ฌ์ฌ ์ ์๋ค. ์ด ๋ฐฉ๋ฒ์ ๋ชจ๋ธ์ ํน์ ๋ถ๋ถ์ ๋ํ ์ ๊ทผ ๊ถํ์ ์ ์ดํ๊ณ ๋ชจ๋ธ์ ์ผ๋ถ๋ถ๋ง์ ๊ด๋ฆฌํ๊ณ ์ ํ ๋ ์ ์ฉํ๋ค.
๋ชจ๋ธ์ ํน์ ๋ถ๋ถ์ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ธฐ ์ํด PyTorch์์๋ state_dict์ ํน์ ํค๋ฅผ ์ฌ์ฉํ ์ ์๋ค.
(state_dict๋ ๋ชจ๋ธ์ ๊ฐ์ค์น์ ํธํฅ ๋ฑ์ ํฌํจํ๋ ์ฌ์ (dictionary)์ด๋ค)
๋ค์์ ๋ชจ๋ธ์ ํน์ ๋ถ๋ถ์ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๋ ๊ฐ๋จํ ์์ ์ด๋ค.
import torch
import torch.nn as nn
# ์์ ๋ชจ๋ธ ์ ์
class ExampleModel(nn.Module):
def __init__(self):
super(ExampleModel, self).__init__()
self.layer1 = nn.Linear(10, 5)
self.layer2 = nn.Linear(5, 2)
# ๋ชจ๋ธ ์ธ์คํด์ค ์์ฑ
model = ExampleModel()
# ๋ชจ๋ธ์ state_dict ๊ฐ์ ธ์ค๊ธฐ
state_dict = model.state_dict()
# state_dict์ ํค ์ถ๋ ฅ
print("Keys in the state_dict:")
for key in state_dict:
print(key)
# ์ํ๋ ๋ถ๋ถ์ ์ ์ฅ
torch.save({'layer1_weights': model.layer1.weight,
'layer1_bias': model.layer1.bias}, 'partial_model.pth')
# ์ ์ฅ๋ ๋ถ๋ถ์ ๋ถ๋ฌ์ค๊ธฐ
partial_state_dict = torch.load('partial_model.pth')
model.layer1.weight.data.copy_(partial_state_dict['layer1_weights'])
model.layer1.bias.data.copy_(partial_state_dict['layer1_bias'])
- model.layer1.weight.data.copy_(partial_state_dict['layer1_weights'])
- ์ด ์ฝ๋๋ model์ ์ฒซ ๋ฒ์งธ ๋ ์ด์ด์ธ layer1์ ๊ฐ์ค์น๋ฅผ partial_state_dict์์ ๊ฐ์ ธ์จ layer1_weights์ ๊ฐ์ผ๋ก ๋ณต์ฌํ๋ค.
- data๋ ํ ์์ ๋ฐ์ดํฐ๋ฅผ ์ง์ ์ ์ผ๋ก ์ ๊ทผํ๊ธฐ ์ํ ์์ฑ์ด๋ฉฐ, .copy_()๋ ๊ฐ์ ๋ณต์ฌํ๋ ๋ฉ์๋์ด๋ค.
- ๋ฐ๋ผ์, layer1_weights์ ๊ฐ์ layer1์ ๊ฐ์ค์น์ ๋ณต์ฌํ๋ค๋ ๋ป์ด๋ค.
- ์ด๋ฐ ์์ผ๋ก ๋ชจ๋ธ์ ํน์ ๋ถ๋ถ์ ๋ํ ๊ฐ์ ๋ณต์ฌํจ์ผ๋ก์จ, ์ ์ฒด ๋ชจ๋ธ์ด ์๋ ์ผ๋ถ๋ถ๋ง์ ์ ๋ฐ์ดํธํ๊ฑฐ๋ ๋ถ๋ฌ์ฌ ์ ์๋ค.
๋ชจ๋ธ์ด ๋ณต์กํด์ state dict์ ํน์ ํค๊ฐ๋ง ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ณ ์ถ์ ๋ ๋ค์๊ณผ ๊ฐ์ด ์ฝ๋๋ฅผ ์์ฑํ ์ ์๋ค.
import torch
# ๋ชจ๋ธ ์ธ์คํด์ค ์์ฑ
model = ...
# ์ ์ฅํ ๋ ํน์ ํค ํํฐ๋ง
filtered_state_dict = {k: v for k, v in model.state_dict().items() if k.startswith('backbone')}
# ํํฐ๋ง๋ state_dict ์ ์ฅ
torch.save(filtered_state_dict, 'backbone_model.pth')
# ์ ์ฅ๋ state_dict ๋ถ๋ฌ์ค๊ธฐ
filtered_state_dict = torch.load('backbone_model.pth')
# ํ์ฌ ๋ชจ๋ธ์ state_dict ๊ฐ์ ธ์ค๊ธฐ
current_state_dict = model.state_dict()
# ๋ถ๋ฌ์จ state_dict๋ฅผ ํ์ฌ ๋ชจ๋ธ์ ์ ์ฉ (์ผ์นํ๋ ํค๋ง)
for key, value in filtered_state_dict.items():
current_state_dict[key] = value
# ๋ชจ๋ธ์ ์ ์ฉ๋ state_dict ์ค์
model.load_state_dict(current_state_dict)
- model์ state dict ํค ๊ฐ์ค 'backbone'์ผ๋ก ์์ํ๋ ์จ์ดํธ๋ง ์ ์ฅ
- ๋ถ๋ฌ์จ ๋ชจ๋ธ ์จ์ดํธ์ ์ผ๋ถ๋ฅผ ํ์ฌ ๋ชจ๋ธ์ state dict์ ์ผ์นํ๋ ํค๊ฐ์ ์ ์ฉ
๋ฐ์ํ