๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
๐Ÿ’ป Programming/AI & ML

[pytorch] DataParallel ๋กœ ํ•™์Šตํ•œ ๋ชจ๋ธ load

by ๋ญ…์ฆค 2021. 2. 17.
๋ฐ˜์‘ํ˜•
model = custom_LSTM()

model = torch.nn.DataParallel(model)
with open(os.path.join('C:/Users/' + 'model_1.pt'), 'rb') as f:
    model.load_state_dict(torch.load(f))

DataParallel ๋กœ ํ•™์Šต์‹œํ‚จ ๋ชจ๋ธ์„ loadํ•ด์„œ ์‚ฌ์šฉํ•  ๋•Œ๋Š” ์œ„์™€ ๊ฐ™์ด torch.nn.DataParallel(model) ์ฝ”๋“œ๋ฅผ ์จ์ค˜์•ผ error ์—†์ด ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜๋‹ค.

 

๋ฐ˜์‘ํ˜•