[pytorch] Saving/Loading Only Part of a Model

2023. 12. 9. 15:31 · 💻 Programming/AI & ML

In PyTorch you can save or load just part of a model instead of the whole thing. This is useful when you want to control which parts of a model are shared, or when you only need to manage a subset of its layers (for example, reusing only some of them in another model).

๋ชจ๋ธ์˜ ํŠน์ • ๋ถ€๋ถ„์„ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ์œ„ํ•ด PyTorch์—์„œ๋Š” state_dict์˜ ํŠน์ • ํ‚ค๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

(A state_dict is a dictionary that maps each parameter and buffer name to its tensor, such as the model's weights and biases.)
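To see what these keys look like, here is a minimal sketch with a hypothetical model (`TinyNet`, with submodules named `backbone` and `head`, is an assumption made for illustration). Each key is the attribute path of a tensor joined by dots, so every tensor inside a submodule shares that submodule's name as a prefix, e.g. `backbone.0.weight` or `head.bias`. Buffers such as BatchNorm's `running_mean` are included alongside the parameters.

```python
import torch
import torch.nn as nn

# Hypothetical model used only to show how state_dict keys are named
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Sequential names its children "0", "1", ...
        self.backbone = nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5))
        self.head = nn.Linear(5, 2)

    def forward(self, x):
        return self.head(self.backbone(x))

model = TinyNet()

# Print every key with the shape of its tensor
for key, tensor in model.state_dict().items():
    print(key, tuple(tensor.shape))
```

Because the keys follow the module hierarchy, selecting "one part of the model" usually comes down to selecting the keys that share a prefix.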

Here is a simple example that saves and loads one specific part of a model.

import torch
import torch.nn as nn

# Define an example model
class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.layer2(self.layer1(x))

# Create a model instance
model = ExampleModel()

# Get the model's state_dict
state_dict = model.state_dict()

# Print the keys in the state_dict (layer1.weight, layer1.bias, layer2.weight, layer2.bias)
print("Keys in the state_dict:")
for key in state_dict:
    print(key)


# Save only the part we want: layer1's weight and bias (detached from autograd)
torch.save({'layer1_weights': model.layer1.weight.detach(),
            'layer1_bias': model.layer1.bias.detach()}, 'partial_model.pth')

# Load the saved part back into the model
partial_state_dict = torch.load('partial_model.pth')
model.layer1.weight.data.copy_(partial_state_dict['layer1_weights'])
model.layer1.bias.data.copy_(partial_state_dict['layer1_bias'])
  • model.layer1.weight.data.copy_(partial_state_dict['layer1_weights'])
    • This line copies the layer1_weights tensor from partial_state_dict into the weight of layer1, the model's first layer.
    • .data gives direct access to the tensor's underlying data without autograd tracking, and .copy_() copies values into the tensor in place.
    • In other words, the values of layer1_weights overwrite layer1's weights.
  • By copying values into specific parts of the model this way, you can update or load just part of the model rather than the whole thing.
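Two related patterns may also be handy here. Every submodule has its own `state_dict()` and `load_state_dict()`, so a single layer can be saved and restored without naming each tensor by hand. And the in-place copy can be written under `torch.no_grad()` instead of going through `.data`; in-place changes made via `.data` are invisible to autograd, so `torch.no_grad()` is generally the safer choice. A minimal sketch, reusing the `ExampleModel` layout from above (the file name `layer1.pth` and the names `src`/`dst` are assumptions):

```python
import torch
import torch.nn as nn

# Same layout as ExampleModel above
class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.layer2(self.layer1(x))

src = ExampleModel()  # model we copy from
dst = ExampleModel()  # model we copy into

# Option 1: save/load the state_dict of one submodule.
# Its keys are relative to the submodule: "weight" and "bias".
torch.save(src.layer1.state_dict(), "layer1.pth")
dst.layer1.load_state_dict(torch.load("layer1.pth"))

# Option 2: copy tensors in place with autograd disabled instead of using .data
with torch.no_grad():
    dst.layer2.weight.copy_(src.layer2.weight)
    dst.layer2.bias.copy_(src.layer2.bias)
```

Because the keys of `src.layer1.state_dict()` are relative, the saved file can be loaded into any `nn.Linear(10, 5)`, not only into `layer1` of this particular model.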

 

๋ชจ๋ธ์ด ๋ณต์žกํ•ด์„œ state dict์˜ ํŠน์ • ํ‚ค๊ฐ’๋งŒ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๊ณ  ์‹ถ์„ ๋• ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ฝ”๋“œ๋ฅผ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค.

import torch

# Create a model instance (assumed to contain a submodule named 'backbone')
model = ...

# When saving, keep only the keys that start with 'backbone'
filtered_state_dict = {k: v for k, v in model.state_dict().items() if k.startswith('backbone')}

# Save the filtered state_dict
torch.save(filtered_state_dict, 'backbone_model.pth')

# Load the saved state_dict
filtered_state_dict = torch.load('backbone_model.pth')

# Get the current model's state_dict
current_state_dict = model.state_dict()

# Apply the loaded values to the current state_dict (matching keys only)
for key, value in filtered_state_dict.items():
    if key in current_state_dict:
        current_state_dict[key] = value

# Load the merged state_dict into the model
model.load_state_dict(current_state_dict)
  • Only the weights whose state dict keys start with 'backbone' are saved.
  • The loaded weights are written into the matching keys of the current model's state dict; all other keys keep their current values.
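The merge loop can also be replaced by `load_state_dict(..., strict=False)`, which loads the keys that are present and returns the missing and unexpected ones instead of raising. Note that `strict=False` does not skip tensors whose shapes differ; a shape mismatch still raises a `RuntimeError`. Below is a runnable sketch with a hypothetical `Net` whose `backbone` and `head` submodules stand in for the elided `model = ...` (the class, its layer sizes, and `num_classes` are assumptions). It also shows how to strip the `backbone.` prefix when the saved weights go into a standalone backbone module.

```python
import torch
import torch.nn as nn

# Hypothetical model: the names "backbone" and "head" are assumptions
class Net(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(10, 5), nn.ReLU())
        self.head = nn.Linear(5, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))

trained = Net()
prefix = "backbone."

# Save only the backbone tensors
backbone_sd = {k: v for k, v in trained.state_dict().items() if k.startswith(prefix)}
torch.save(backbone_sd, "backbone_model.pth")

# Load into a new model whose head has a different size.
# strict=False: keys absent from the file (the head) are reported, not an error.
new_model = Net(num_classes=7)
result = new_model.load_state_dict(torch.load("backbone_model.pth"), strict=False)
print("missing:", result.missing_keys)        # the head keys, which were not saved
print("unexpected:", result.unexpected_keys)  # empty in this example

# Load into a bare backbone module: strip the "backbone." prefix first
bare = nn.Sequential(nn.Linear(10, 5), nn.ReLU())
bare.load_state_dict({k[len(prefix):]: v for k, v in backbone_sd.items()})
```

Checking `missing_keys` after a `strict=False` load is a cheap way to confirm that only the parts you intended were left untouched.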