๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
๐Ÿ’ป Programming/AI & ML

[pytorch] pytorch ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ Missing key(s) in state_dict ์—๋Ÿฌ

by ๋ญ…์ฆค 2022. 12. 15.
๋ฐ˜์‘ํ˜•

pytorch๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ฌ ๋•Œ Missing key(s) in state_dict ๋ผ๋Š” ๋Ÿฐํƒ€์ž„ ์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ•˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ์ข…์ข… ์žˆ๋‹ค.

 

๋Œ€๋ถ€๋ถ„ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ๊ณผ ๋ถˆ๋Ÿฌ์˜จ ๋ชจ๋ธ ์›จ์ดํŠธ์˜ ํ‚ค๊ฐ’์ด ๋งž์ง€ ์•Š์•„์„œ ๋ฐœ์ƒํ•˜๋Š” ์˜ค๋ฅ˜์ธ๋ฐ, ๋ชจ๋ธ๊ณผ ๋ชจ๋ธ ์›จ์ดํŠธ๊ฐ€ ์„œ๋กœ ์ง์ด ์•„๋‹Œ ๊ฒฝ์šฐ์— ๋ฐœ์ƒํ•˜๊ณ  ๊ฐ„ํ˜น ๊ทธ๋ ‡์ง€ ์•Š์€ ๊ฒฝ์šฐ์—๋„ ํ•ด๋‹น ์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ•ด์„œ ์šฐ๋ฆฌ๋ฅผ ๊ดด๋กญํžŒ๋‹ค... ใ…Ž

 

๊ทธ ๋•Œ ์•„๋ž˜์™€ ๊ฐ™์ด torch.load๋กœ ๋ชจ๋ธ์˜ state dict๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๊ณ  ๋””๋ฒ„๊น…์„ ํ•ด์„œ state dict์˜ ํ‚ค๊ฐ’์„ ํ™•์ธํ•ด๋ณด๋ฉด ์ •ํ™•ํ•˜๊ฒŒ ๋ฌธ์ œ๋ฅผ ํŒŒ์•…ํ•  ์ˆ˜ ์žˆ๋‹ค.

model_path = './model.pth'
model_state_dict = torch.load(model_path)

 

 

์œ„ ์ฒซ ๋ฒˆ์งธ ์บก์ฒ˜๋ฅผ ๋ณด๋ฉด ๋ชจ๋ธ state dict์— ๋ชจ๋ธ ์›จ์ดํŠธ์™€ ํ•™์Šต ์ •๋ณด์ธ iter, optimizer, scaler ๋“ฑ์˜ ํ‚ค๊ฐ’์ด ํ•จ๊ป˜ ์ €์žฅ๋˜์–ด ์žˆ๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. ๊ทธ๋ž˜์„œ ๋ชจ๋ธ์— ๋ฐ”๋กœ ๋ถˆ๋Ÿฌ์˜ค๋ฉด Missing key(s) ์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ•œ๋‹ค. ์ด ๋•Œ ๋ชจ๋ธ์˜ ์›จ์ดํŠธ์— ํ•ด๋‹นํ•˜๋Š” ํ‚ค๊ฐ’์œผ๋กœ ์ธ๋ฑ์‹ฑํ•ด์ฃผ๊ณ  ๋ชจ๋ธ state dict๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋ฉด ์ •์ƒ์ ์œผ๋กœ ๋ชจ๋“  ํ‚ค๊ฐ€ ๋งค์น˜๋œ๋‹ค.

 

๊ผญ ์ด๋Ÿฐ ๊ฒฝ์šฐ๊ฐ€ ์•„๋‹ˆ๋”๋ผ๋„ ๋ถˆ๋Ÿฌ์˜จ ๋ชจ๋ธ state dict์˜ ํ‚ค๊ฐ’๋“ค์„ ํ™•์ธํ•˜๋ฉด ์™„์ „ํžˆ ๋‹ค๋ฅธ ๋ชจ๋ธ์˜ ์›จ์ดํŠธ๋ฅผ ๊ฐ€์ ธ์™”๋Š”์ง€, ํ‚ค๊ฐ’๋งŒ ์กฐ๊ธˆ ๋‹ค๋ฅธ ๋ถ€๋ถ„์ด ์žˆ๋Š”์ง€ ๋“ฑ์˜ ๋ฌธ์ œ๋ฅผ ํŒŒ์•…ํ•  ์ˆ˜ ์žˆ๋‹ค.

๋ฐ˜์‘ํ˜•