๐ Theory/AI & ML
[DL] ๋ฅ๋ฌ๋์์์ Regularization : Weight Decay, Batch Normalization, Early Stopping
๋ญ
์ฆค
2022. 3. 23. 22:39
๋ฐ์ํ
๋ฅ๋ฌ๋์์ Regularization์ ๋ชจ๋ธ์ overfitting์ ๋ฐฉ์งํ๊ธฐ ์ํด ํน์ ํ ๊ฒ์ ๊ท์ ๋ฅผ ํ๋ ๋ฐฉ๋ฒ๋ค์ ์ด์นญํ๊ณ , ๋ํ์ ์ผ๋ก ์๋์ ๊ฐ์ ๋ฐฉ๋ฒ๋ค์ด ์๋ค.
*Overfitting : ๊ธฐ๊ณ ํ์ต ๋ชจ๋ธ์์ ์์ฃผ ๋ฐ์ํ๋ ๋ฌธ์ ์ค ํ๋๋ก, ๋ชจ๋ธ์ด ํ์ต ๋ฐ์ดํฐ์ ์ ๊ณผ๋ํ๊ฒ fit๋์ด ์ผ๋ฐํ ์ฑ๋ฅ์ด ๋จ์ด์ง๋ ํ์.
- Weight Decay - L1, L2
- Batch Normalization
- Early Stopping
Weight Decay
- Neural network์ ํน์ weight๊ฐ ๋๋ฌด ์ปค์ง๋ ๊ฒ์ ๋ชจ๋ธ์ ์ผ๋ฐํ ์ฑ๋ฅ์ ๋จ์ด๋จ๋ ค overfitting ๋๊ฒ ํ๋ฏ๋ก, weight์ ๊ท์ ๋ฅผ ๊ฑธ์ด์ฃผ๋ ๊ฒ์ด ํ์.
- L1 regularization, L2 regularization ๋ชจ๋ ๊ธฐ์กด Loss function์ weight์ ํฌ๊ธฐ๋ฅผ ํฌํจํ์ฌ weight์ ํฌ๊ธฐ๊ฐ ์์์ง๋ ๋ฐฉํฅ์ผ๋ก ํ์ตํ๋๋ก ๊ท์
L1 Regularization vs L2 Regularization
- L1 Regularization : weight ์ ๋ฐ์ดํธ ์ weight์ ํฌ๊ธฐ์ ๊ด๊ณ์์ด ์์๊ฐ์ ๋นผ๊ฒ ๋๋ฏ๋ก(loss function ๋ฏธ๋ถํ๋ฉด ํ์ธ ๊ฐ๋ฅ) ์์ weight ๋ค์ 0์ผ๋ก ์๋ ดํ๊ณ , ๋ช๋ช ์ค์ํ weight ๋ค๋ง ๋จ์. ๋ช ๊ฐ์ ์๋ฏธ์๋ ๊ฐ์ ์ฐ์ถํ๊ณ ์ถ์ sparse model ๊ฐ์ ๊ฒฝ์ฐ์ L1 Regularization์ด ํจ๊ณผ์ . ๋ค๋ง ์๋ ๊ทธ๋ฆผ์์ ๋ณด๋ฏ์ด ๋ฏธ๋ถ ๋ถ๊ฐ๋ฅํ ์ง์ ์ด ์๊ธฐ ๋๋ฌธ์ gradient-base learning ์์๋ ์ฃผ์๊ฐ ํ์.
- L2 Regularization : weight ์ ๋ฐ์ดํธ ์ weight์ ํฌ๊ธฐ๊ฐ ์ง์ ์ ์ธ ์ํฅ์ ๋ผ์ณ weight decay์ ๋์ฑ ํจ๊ณผ์
Batch Normalization
- Gradient vanishing/exploding ์ ๋ฐฉ์งํ๊ธฐ ์ํด ํ์ต ๊ณผ์ ์์ฒด๋ฅผ ์์ ํ์ํค๊ธฐ ์ํ ๋ฐฉ๋ฒ
- ํ์ต์ ๋คํธ์ํฌ์ ๊ฐ layer ๋๋ activation ๋ง๋ค ์ ๋ ฅ ๊ฐ์ ๋ถํฌ๊ฐ ๋ฌ๋ผ์ง๋ "Internal Covariance Shift" ๊ฐ ๋ฐ์ํ๊ณ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ ๋ ฅ๊ฐ์ ๋ถํฌ๋ฅผ ์กฐ์
- ํ๊ท ๊ณผ ๋ถ์ฐ์ ์กฐ์ ํ๋ ๊ณผ์ ์ด neural network ๋ด๋ถ์ ํฌํจ๋์ด ํ์ต์ batch์ ํ๊ท ๊ณผ ๋ถ์ฐ์ ์ด์ฉํ์ฌ ์ ๊ทํ
- scale๊ณผ shift(bias)๋ฅผ ๊ฐ๋ง, ๋ฒ ํ ๊ฐ์ผ๋ก ์กฐ์
- Inference ์์๋ ๋ฐฐ์น ๋จ์์ ํ๊ท ๊ณผ ๋ถ์ฐ์ ๊ตฌํ ์ ์๊ธฐ ๋๋ฌธ์ ํ์ต ๋จ๊ณ์์ moving average ๋๋ exponential average๋ฅผ ์ด์ฉํ์ฌ ๊ณ์ฐํ ํ๊ท ๊ณผ ๋ถ์ฐ์ ๊ณ ์ ๊ฐ์ผ๋ก ์ฌ์ฉ
Batch Normalization ํจ๊ณผ
- Gradient vanishing/exploding ์ ์ํํ๋ฏ๋ก ๋์ learning rate ์ฌ์ฉํ์ฌ ํ์ต ์๋ ํฅ์
- Careful weight initialization์ผ๋ก ๋ถํฐ ์์ ๋ก์์ง
- Regularization ํจ๊ณผ : BN ๊ณผ์ ์ผ๋ก ํ๊ท ๊ณผ ๋ถ์ฐ์ด ์ง์์ ์ผ๋ก ๋ณํ๊ณ weight ์ ๋ฐ์ดํธ์๋ ์ํฅ์ ์ฃผ์ด ํ๋์ weight ๊ฐ ๋งค์ฐ ์ปค์ง๋ ๊ฒ์ ๋ฐฉ์ง.
Batch Normalization ์ฃผ์ ์ฌํญ
- Batch size ๊ฐ ๋๋ฌด ํฌ๊ฑฐ๋ ์์ผ๋ฉด ํจ๊ณผ๋ฅผ ๊ธฐ๋ํ๊ธฐ ์ด๋ ค์
- ์ฌ์ฉ ์์ : Convolution - BN - Activation - Pooling - ... (BN์ ๋ชฉ์ ์ด ๋คํธ์ํฌ ์ฐ์ฐ ๊ฒฐ๊ณผ๊ฐ ์ํ๋ ๋ฐฉํฅ์ ๋ถํฌ๋๋ก ๋์ค๊ฒ ํ๋ ๊ฒ์ด๋ฏ๋ก conv ์ฐ์ฐ ๋ฐ๋ก ๋ค์ ์ฃผ๋ก ์ฌ์ฉ/ ์๋ ๊ฒฝ์ฐ๋ ์์ต๋๋ค.)
- Multi GPU training ์ ์ฃผ๋ก "Synchronized Batch Normalization" ์ฌ์ฉ
Early Stopping
- Deep Neural Network๋ ์ผ๋ฐ์ ์ผ๋ก ํ์ต์ ๋๋ฌด ๋ง์ดํ๋ฉด ํน์ epoch ์ดํ์๋ overftting์ด ๋ฐ์ํ์ฌ test ์ฑ๋ฅ ํ๋ฝ
- ์ด๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด validation set์ ์ด์ฉํ๋ ๋ฑ์ ๋ฐฉ๋ฒ์ผ๋ก overfitting์ด ๋ฐ์ํ๊ธฐ ์ ์ ํ์ต์ ์ข ๋ฃํ๋ ๋ฐฉ๋ฒ
๋ฐ์ํ