[pytorch] DataParallel 로 학습한 모델 load
·
💻 Programming/AI & ML
model = custom_LSTM() model = torch.nn.DataParallel(model) with open(os.path.join('C:/Users/' + 'model_1.pt'), 'rb') as f: model.load_state_dict(torch.load(f)) DataParallel 로 학습시킨 모델을 load해서 사용할 때는 위와 같이 torch.nn.DataParallel(model) 코드를 써줘야 error 없이 사용 가능하다.