pytorch๋ก ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ฌ ๋ Missing key(s) in state_dict ๋ผ๋ ๋ฐํ์ ์๋ฌ๊ฐ ๋ฐ์ํ๋ ๊ฒฝ์ฐ๊ฐ ์ข ์ข ์๋ค.
๋๋ถ๋ถ ๋ฅ๋ฌ๋ ๋ชจ๋ธ๊ณผ ๋ถ๋ฌ์จ ๋ชจ๋ธ ์จ์ดํธ์ ํค๊ฐ์ด ๋ง์ง ์์์ ๋ฐ์ํ๋ ์ค๋ฅ์ธ๋ฐ, ๋ชจ๋ธ๊ณผ ๋ชจ๋ธ ์จ์ดํธ๊ฐ ์๋ก ์ง์ด ์๋ ๊ฒฝ์ฐ์ ๋ฐ์ํ๊ณ ๊ฐํน ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ์๋ ํด๋น ์๋ฌ๊ฐ ๋ฐ์ํด์ ์ฐ๋ฆฌ๋ฅผ ๊ดด๋กญํ๋ค... ใ
๊ทธ ๋ ์๋์ ๊ฐ์ด torch.load๋ก ๋ชจ๋ธ์ state dict๋ฅผ ๋ถ๋ฌ์ค๊ณ ๋๋ฒ๊น ์ ํด์ state dict์ ํค๊ฐ์ ํ์ธํด๋ณด๋ฉด ์ ํํ๊ฒ ๋ฌธ์ ๋ฅผ ํ์ ํ ์ ์๋ค.
model_path = './model.pth'
model_state_dict = torch.load(model_path)
์ ์ฒซ ๋ฒ์งธ ์บก์ฒ๋ฅผ ๋ณด๋ฉด ๋ชจ๋ธ state dict์ ๋ชจ๋ธ ์จ์ดํธ์ ํ์ต ์ ๋ณด์ธ iter, optimizer, scaler ๋ฑ์ ํค๊ฐ์ด ํจ๊ป ์ ์ฅ๋์ด ์๋ ๊ฒ์ ๋ณผ ์ ์๋ค. ๊ทธ๋์ ๋ชจ๋ธ์ ๋ฐ๋ก ๋ถ๋ฌ์ค๋ฉด Missing key(s) ์๋ฌ๊ฐ ๋ฐ์ํ๋ค. ์ด ๋ ๋ชจ๋ธ์ ์จ์ดํธ์ ํด๋นํ๋ ํค๊ฐ์ผ๋ก ์ธ๋ฑ์ฑํด์ฃผ๊ณ ๋ชจ๋ธ state dict๋ฅผ ๋ถ๋ฌ์ค๋ฉด ์ ์์ ์ผ๋ก ๋ชจ๋ ํค๊ฐ ๋งค์น๋๋ค.
๊ผญ ์ด๋ฐ ๊ฒฝ์ฐ๊ฐ ์๋๋๋ผ๋ ๋ถ๋ฌ์จ ๋ชจ๋ธ state dict์ ํค๊ฐ๋ค์ ํ์ธํ๋ฉด ์์ ํ ๋ค๋ฅธ ๋ชจ๋ธ์ ์จ์ดํธ๋ฅผ ๊ฐ์ ธ์๋์ง, ํค๊ฐ๋ง ์กฐ๊ธ ๋ค๋ฅธ ๋ถ๋ถ์ด ์๋์ง ๋ฑ์ ๋ฌธ์ ๋ฅผ ํ์ ํ ์ ์๋ค.