[pytorch] Saving and loading only part of a model

by 뭅즤 2023. 12. 9.

In PyTorch, you can save or load only part of a model. This is useful when you want to control which parts of a model are exposed and to manage a single part, such as a backbone, separately from the rest.

To save and load a specific part of a model, you can work with specific keys of its state_dict.

(A state_dict is a dictionary that holds the model's weights, biases, and other persistent tensors.)
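As a quick illustration of what such a dictionary looks like, the sketch below lists each key with its tensor shape. `TinyModel` is a hypothetical module made up for this example; the point is only that keys follow the pattern `<attribute name>.<parameter name>`, which is what makes filtering by key possible.

```python
import torch.nn as nn

# Hypothetical two-layer module, used only to show how state_dict keys are named.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 2)

tiny = TinyModel()

# Map each state_dict key to the shape of the tensor stored under it.
shapes = {name: tuple(tensor.shape) for name, tensor in tiny.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
# layer1.weight (5, 10)
# layer1.bias (5,)
# layer2.weight (2, 5)
# layer2.bias (2,)
```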

The following is a simple example of saving and loading a specific part of a model.

import torch
import torch.nn as nn

# Define an example model
class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 2)

# Create a model instance
model = ExampleModel()

# Get the model's state_dict
state_dict = model.state_dict()

# Print the keys in the state_dict
print("Keys in the state_dict:")
for key in state_dict:
    print(key)

# Save only the desired part
torch.save({'layer1_weights': model.layer1.weight,
            'layer1_bias': model.layer1.bias}, 'partial_model.pth')

# Load the saved part
partial_state_dict = torch.load('partial_model.pth')
model.layer1.weight.data.copy_(partial_state_dict['layer1_weights'])
model.layer1.bias.data.copy_(partial_state_dict['layer1_bias'])
  • model.layer1.weight.data.copy_(partial_state_dict['layer1_weights'])
    • This line copies the value stored under layer1_weights in partial_state_dict into the weight of layer1, the model's first layer.
    • data is an attribute that gives direct access to the tensor's underlying data without autograd tracking, and .copy_() is a method that copies values in place.
    • In other words, the value of layer1_weights is copied into the weight of layer1.
  • By copying values into a specific part of the model in this way, you can update or load only that part rather than the whole model.
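A variation on the steps above is to save and load the submodule's own state_dict rather than copying tensors one at a time. This is a minimal sketch, not the method used in this post: the file name `layer1_only.pth` and the temporary directory are assumptions made for illustration, and `source` and `target` are two hypothetical instances of the same model class.

```python
import os
import tempfile

import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 2)

source = ExampleModel()
target = ExampleModel()

# Keep a copy of layer2 so we can confirm it is left untouched.
layer2_before = target.layer2.weight.detach().clone()

# Hypothetical file location, chosen only for this sketch.
path = os.path.join(tempfile.mkdtemp(), "layer1_only.pth")

# A submodule has its own state_dict, here with the keys "weight" and "bias".
torch.save(source.layer1.state_dict(), path)

# Load the file into the matching submodule of another instance.
target.layer1.load_state_dict(torch.load(path))

layer1_copied = torch.equal(source.layer1.weight, target.layer1.weight)
layer2_untouched = torch.equal(layer2_before, target.layer2.weight)
print(layer1_copied, layer2_untouched)
```

Because the saved dictionary uses the submodule's own key names, nothing has to be renamed when loading, and only layer1 changes.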

 

When the model is complex and you want to save and load only certain keys of the state dict, you can write code like the following.

import torch

# Create a model instance
model = ...

# Filter for specific keys when saving
filtered_state_dict = {k: v for k, v in model.state_dict().items() if k.startswith('backbone')}

# Save the filtered state_dict
torch.save(filtered_state_dict, 'backbone_model.pth')

# Load the saved state_dict
filtered_state_dict = torch.load('backbone_model.pth')

# Get the current model's state_dict
current_state_dict = model.state_dict()

# Apply the loaded state_dict to the current model (matching keys only)
for key, value in filtered_state_dict.items():
    if key in current_state_dict:
        current_state_dict[key] = value

# Load the updated state_dict into the model
model.load_state_dict(current_state_dict)
  • Only the weights whose state dict keys start with 'backbone' are saved.
  • The loaded subset of weights is applied to the matching keys in the current model's state dict.
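The merge step can also be left to `load_state_dict` itself by passing `strict=False`, which skips keys missing from the checkpoint and reports them. The sketch below assumes a hypothetical `Net` class with `backbone` and `head` submodules and skips the round trip through a file to stay short.

```python
import torch
import torch.nn as nn

# Hypothetical model with a "backbone" and a "head", assumed for illustration.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head = nn.Linear(4, 2)

source = Net()
target = Net()

# Keep only the entries whose keys start with "backbone".
backbone_only = {k: v for k, v in source.state_dict().items() if k.startswith("backbone")}

# strict=False tolerates keys that are absent from the checkpoint and reports them.
result = target.load_state_dict(backbone_only, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
```

Checking `missing_keys` and `unexpected_keys` after loading is a cheap way to catch a mistyped prefix. Note that `strict=False` only relaxes key matching; a tensor whose shape differs from the model's parameter still raises an error.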