๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
728x90

๐Ÿ’ป Programming/AI & ML16

[pytorch] model ์— ์ ‘๊ทผํ•˜๊ธฐ, ํŠน์ • layer ๋ณ€๊ฒฝํ•˜๊ธฐ pytorch ๋ชจ๋ธ์— ์ ‘๊ทผํ•˜๊ณ  ํŠน์ • layer ๋˜๋Š” layer ๋‚ด๋ถ€์˜ ๋ชจ๋“ˆ์„ ๋ณ€๊ฒฝํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์ •๋ฆฌํ•œ๋‹ค. - ์˜ˆ์‹œ ๋ชจ๋ธ : resnet50 import torch.nn as nn import torchvision.models as models model = models.resnet50(pretrained=True) 1. self.named_parameters() for name, param in model.layer1.named_parameters(): print(name,param.shape,sep=" ") 2. self.named_children() for name,ch in model.layer1.named_children(): print("name :",name) print("child :", ch.. 2022. 1. 5.
[pytorch] Custom dataset, dataloader ๋งŒ๋“ค๊ธฐ * dataset ํด๋” ๊ตฌ์กฐ minc2500 โ”œโ”€images โ”‚ โ”œโ”€brick โ”‚ โ”‚ โ”œโ”€brick_000000.jpg โ”‚ โ”‚ โ”œโ”€brick_000001.jpg โ”‚ โ”‚ โ”œโ”€... โ”‚ โ”œโ”€carpet โ”‚ โ”‚ โ”œโ”€carpet_000000.jpg โ”‚ โ”‚ โ”œโ”€... โ”‚ โ”œโ”€... โ”‚ โ”‚ โ”œโ”€... โ”‚ โ”‚ โ”œโ”€... ... ... ... โ”œโ”€labels โ”‚ โ”œโ”€train1.txt โ”‚ โ”œโ”€train2.txt โ”‚ โ”œโ”€... โ”‚ โ”œโ”€test1.txt โ”‚ โ”œโ”€test2.txt โ”‚ โ””โ”€... import os import os.path import torch import torch.utils.data as data from PIL import Image from torchvision import transforms imp.. 2022. 1. 2.
[pytorch] DataParallel ๋กœ ํ•™์Šตํ•œ ๋ชจ๋ธ load model = custom_LSTM() model = torch.nn.DataParallel(model) with open(os.path.join('C:/Users/' + 'model_1.pt'), 'rb') as f: model.load_state_dict(torch.load(f)) DataParallel ๋กœ ํ•™์Šต์‹œํ‚จ ๋ชจ๋ธ์„ loadํ•ด์„œ ์‚ฌ์šฉํ•  ๋•Œ๋Š” ์œ„์™€ ๊ฐ™์ด torch.nn.DataParallel(model) ์ฝ”๋“œ๋ฅผ ์จ์ค˜์•ผ error ์—†์ด ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜๋‹ค. 2021. 2. 17.
[pytorch] ๋ชจ๋ธ์˜ ์ผ๋ถ€ ๋ ˆ์ด์–ด ์›จ์ดํŠธ ์—…๋ฐ์ดํŠธ ๋ง‰๊ธฐ | model freezing (๋ชจ๋ธ ํ”„๋ฆฌ์ง•) PyTorch์—์„œ ๋ชจ๋ธ ์ผ๋ถ€๋ถ„์˜ ํ•™์Šต์„๋ง‰๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ•ด๋‹น ๋ถ€๋ถ„์˜ ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•ด requires_grad ์†์„ฑ์„ False๋กœ ์„ค์ •ํ•˜๋ฉด ๋œ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ๊ทธ ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ gradient๊ฐ€ ๊ณ„์‚ฐ๋˜์ง€ ์•Š์•„ ์›จ์ดํŠธ๊ฐ€ ์—…๋ฐ์ดํŠธ๋˜์ง€ ์•Š๋Š”๋‹ค. ๊ธฐ๋ณธ์ ์œผ๋กœ model.named_parameters() ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ name๊ณผ param์„ ์ถ”์ถœํ•˜๊ณ  ์›ํ•˜๋Š” name์˜ ๋ ˆ์ด์–ด๋งŒ requires_grad๋ฅผ ๋ณ€๊ฒฝํ•  ์ˆ˜ ์žˆ๋‹ค. ๋ชจ๋ธ ํ”„๋ฆฌ์ง• ์˜ˆ์‹œ - Resnet ์˜ˆ์‹œ import torch import torch.nn as nn import torchvision.models as models # ResNet-50 ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ (pre-trained) resnet50 = models.resnet50(pretrained=Tr.. 2021. 2. 17.
728x90