VAE (Variational Autoencoder)
VAE(Variational Autoencoder)๋ ์์ฑ ๋ชจ๋ธ ์ค ํ๋๋ก, ์ฃผ๋ก ์ฐจ์ ์ถ์ ๋ฐ ์์ฑ ์์
์ ์ฌ์ฉ๋๋ ์ ๊ฒฝ๋ง ์ํคํ
์ฒ์ด๋ค. VAE๋ ๋ฐ์ดํฐ์ ์ ์ฌ ๋ณ์๋ฅผ ํ์ตํ๊ณ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ์๋ก์ด ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๋ ๋ฐ ์ฌ์ฉ๋๋๋ฐ, ํนํ ์ด๋ฏธ์ง ๋ฐ ์์ฑ ์์ฑ๊ณผ ๊ฐ์ ์์ฉ ๋ถ์ผ์์ ๋๋ฆฌ ์ฌ์ฉ๋๊ณ ์๋ค. ์ด๋ฌํ VAE๋ ํฌ๊ฒ ์ธ์ฝ๋์ ๋์ฝ๋๋ผ๋ ๋ ๋ถ๋ถ์ผ๋ก ๊ตฌ์ฑ๋์ด ์๋ค.
Autoencoder(์คํ ์ธ์ฝ๋)์ ํท๊ฐ๋ฆด ์ ์๋๋ฐ,
์คํ ์ธ์ฝ๋๋ ์ธํ์ ๋๊ฐ์ด ๋ณต์ํ ์ ์๋ latent variable z๋ฅผ ๋ง๋๋ ๊ฒ์ด ๋ชฉ์ , ์ฆ ์ธ์ฝ๋๋ฅผ ํ์ตํ๋ ๊ฒ์ด ์ฃผ ๋ชฉ์ ์ด๊ณ ,
VAE์ ๊ฒฝ์ฐ ์ธํ x๋ฅผ ์ ํํํ๋ latent vector๋ฅผ ์ถ์ถํ๊ณ , ์ด๋ฅผ ํตํด ์ธํ x์ ์ ์ฌํ์ง๋ง ์๋ก์ด ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๋ ๊ฒ์ด ๋ชฉ์ ์ด๊ธฐ์ ์ด์ ์ด ๋์ฝ๋์ ๋ง์ถฐ์ ธ ์๋ค.
- ์ธ์ฝ๋ (Encoder)
- ์ธ์ฝ๋๋ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ฃผ์ด์ง ์ ์ฌ ๊ณต๊ฐ์ ํ๋ฅ ๋ถํฌ๋ก ๋งคํํ๋ ์ญํ ์ ์ํ
- ์ฃผ์ด์ง ์ ๋ ฅ ๋ฐ์ดํฐ x์ ๋ํด, ์ธ์ฝ๋๋ ํ๊ท (μ)๊ณผ ํ์ค ํธ์ฐจ(σ)๋ฅผ ๊ฐ์ง ์ ๊ท ๋ถํฌ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ถ๋ ฅ
- ์ ๋ ฅ ๋ฐ์ดํฐ x๋ฅผ ์ ์ฌ ๋ณ์ z๋ก ๋ณํํ๋ ํจ์๋ก๋ ๋ณผ ์ ์์
- ์ ์ฌ ๋ณ์ (Latent Variable)
- latent variable z๋ ๋ฐ์ดํฐ๋ฅผ ํน์ ํ ํน์ฑ์ด๋ ์์ฝ๋ ํํ๋ก ํํํ๋๋ฐ ์ฌ์ฉ๋๋ ๋ณ์์ด๋ค.
- ์ด ๋ณ์๋ ์ธ์ฝ๋์ ์ํด ํ์ต๋๋ฉฐ, ๋ณดํต ํ๊ท (μ)๊ณผ ํ์ค ํธ์ฐจ(σ)๋ฅผ ์ฌ์ฉํ์ฌ ์ ์๋๋ค.
- latent variable์ ํ์ค ์ ๊ท ๋ถํฌ์์ ์ํ๋ง๋๋ค.
- ์คํ ์ธ์ฝ๋๋ ์ธํ๊ณผ ์์ํ์ด ํญ์ ๊ฐ๋๋ก ํ๋ ๊ฒ์ด ๋ชฉ์ ์ด๋ผ๋ฉด, VAE๋ ๊ทธ๋ ์ง ์๊ธฐ ๋๋ฌธ์ ๋ ธ์ด์ฆ๋ฅผ ์ํ๋งํ์ฌ ์ด๋ฅผ ํตํด ์ ์ฌ ๋ณ์๋ฅผ ๋ง๋ ๋ค.
- ์คํ ์ธ์ฝ๋์์ ์ ์ฌ ๋ณ์๊ฐ ํ๋์ ๊ฐ์ด๋ผ๋ฉด, VAE์์์ ์ ์ฌ ๋ณ์๋ ๊ฐ์ฐ์์ ํ๋ฅ ๋ถํฌ์ ๊ธฐ๋ฐํ ํ๋ฅ ๊ฐ
- ๋์ฝ๋ (Decoder)
- ๋์ฝ๋๋ ์ ์ฌ ๋ณ์ z๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์ ์๋์ ๋ฐ์ดํฐ x๋ฅผ ๋ณต์ํ๋ ์ญํ ์ ์ํ
- ๋์ฝ๋๋ ์ ์ฌ ๋ณ์๋ก๋ถํฐ ์์ฑ๋ ๋ฐ์ดํฐ์ ๋ถํฌ๋ฅผ ํ์ต
- ์ ์ฌ ๋ณ์๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์, ๋์ฝ๋๋ ์ ์ฌ ๋ณ์์์ ์๋ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๋ ๋ถํฌ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ถ๋ ฅ
- ์ฌ๊ตฌ์ฑ ์์ค (Reconstruction Loss):
- VAE์ ํ์ต์ ์ฌ๊ตฌ์ฑ ์์ค์ ์ต์ํํ๋๋ก ์ด๋ฃจ์ด์ง
- ์ฌ๊ตฌ์ฑ ์์ค์ ๋์ฝ๋๊ฐ ์ ์ฌ ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ผ๋ง๋ ์ ์ฌ๊ตฌ์ฑํ๋์ง๋ฅผ ์ธก์
- ์ด ์์ค์ ์ฃผ๋ก ํ๊ท ์ ๊ณฑ ์ค์ฐจ(Mean Squared Error)๋ ๊ต์ฐจ ์ํธ๋กํผ ์์ค(Cross-Entropy Loss)๋ก ๊ณ์ฐ ๋จ
- KL ๋ฐ์ฐ (KL Divergence)
- VAE์์๋ ์ ์ฌ ๋ณ์์ ๋ถํฌ๊ฐ ํ์ค ์ ๊ท ๋ถํฌ์ ์ ์ฌํ๋๋ก ์ ๋ํ๋ ์ถ๊ฐ์ ์ธ ํญ์ธ KL divergence๊ฐ ์กด์ฌ
- ์ด ํญ์ ๋ชจ๋ธ์ด ํ์ต๋ ๋ถํฌ์ ์๋์ ์ ๊ท ๋ถํฌ ๊ฐ์ ์ฐจ์ด๋ฅผ ์ธก์ ํ๊ณ , latent variable์ด ๊ณ ๋ฅด๊ฒ ๋ถํฌํ๋๋ก ํ๋ค.
# VAE ๋ชจ๋ธ pytorch ์ฝ๋ ์์
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
# VAE ๋ชจ๋ธ ์ ์
class VAE(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super(VAE, self).__init__()
# ์ธ์ฝ๋ ์ ์
self.encoder = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, latent_size * 2) # mu์ logvar์ ๋์์ ์ถ๋ ฅ
)
# ๋์ฝ๋ ์ ์
self.decoder = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, input_size),
nn.Sigmoid() # ์ด๋ฏธ์ง ์์ฑ์ ์ํด Sigmoid ์ฌ์ฉ
)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
# ์ธ์ฝ๋๋ฅผ ํตํด mu์ logvar ๊ณ์ฐ
enc_output = self.encoder(x)
mu, logvar = enc_output[:, :latent_size], enc_output[:, latent_size:]
# ๋ฆฌํ๋ผ๋ฏธํฐํ ํธ๋ฆญ์ ์ฌ์ฉํ์ฌ ์ ์ฌ ๋ณ์ ์ํ๋ง
z = self.reparameterize(mu, logvar)
# ๋์ฝ๋๋ฅผ ํตํด ์ ์ฌ ๋ณ์๋ก๋ถํฐ ์ด๋ฏธ์ง ์์ฑ
recon_x = self.decoder(z)
return recon_x, mu, logvar
- ์ฃผ๋ชฉํ ๋งํ ๋ถ๋ถ์ latent vector์ ๋ฆฌํ๋ผ๋ฏธํฐํ(reparameterization) ๊ณผ์ ์ผ๋ก ํ๊ท (mu)๊ณผ ๋ก๊ทธ ๋ถ์ฐ(logvar)์ ์ฌ์ฉํ์ฌ latent vector๋ฅผ ์ํ๋ง
- std = torch.exp(0.5 * logvar) : ๋ก๊ทธ ๋ถ์ฐ์ผ๋ก ํ์ค ํธ์ฐจ๋ฅผ ๊ณ์ฐ
- eps = torch.randn_like(std) : ํ์ค ์ ๊ท ๋ถํฌ์์ ์ํ๋ง๋ ๋๋ค ๋ ธ์ด์ฆ๋ฅผ ์์ฑ. torch.randn_like() ํจ์๋ ์ฃผ์ด์ง ํ ์์ ๊ฐ์ ํฌ๊ธฐ์ ํ์ค ์ ๊ท ๋ถํฌ์์ ์ํ๋ง๋ ๋๋ค๊ฐ์ ์์ฑ ํจ
- mu + eps * std : ์ต์ข ์ ์ผ๋ก latent vector๋ฅผ ๊ณ์ฐ. ์ด๋ ๊ฒ ๊ตฌํ latent vector๋ ๋์ฝ๋์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉ
# VAE Loss ๊ณ์ฐ
# Loss ๊ณ์ฐ
BCE = criterion(recon_data, data)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = BCE + KLD
VAE์ loss๋ Reconstruction loss์ KL divergence๋ก ๊ตฌ์ฑ
์ ์ฒด ์์ค์ ์ฌ๊ตฌ์ฑ ์์ค๊ณผ KL ๋ฐ์ฐ์ ํฉ์ผ๋ก ๊ตฌ์ฑ → loss = BCE + KLD.
1. Reconstruction Loss
- criterion = nn.BCELoss(reduction='sum'): ์ฌ๊ธฐ์๋ ์ด์ง ๊ต์ฐจ ์ํธ๋กํผ ์์ค ํจ์๋ฅผ ์ฌ์ฉํ๊ณ , 'sum' reduction์ ์ ํ (๋ชจ๋ ํฝ์ ์ ๋ํ loss์ ํฉ๊ณ๋ฅผ ๊ตฌํ๋ ๋ฐฉ์)
- BCE = criterion(recon_data, data): Reconstruction Loss๋ ์์ฑ๋(recon_data) ์ด๋ฏธ์ง์ ์๋ณธ(data) ์ด๋ฏธ์ง ๊ฐ์ ์ด์ง ๊ต์ฐจ ์ํธ๋กํผ๋ฅผ ์ธก์
- ๋ชฉํ๋ ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ์๋ณธ ์ด๋ฏธ์ง์ ์ ์ฌํ๊ฒ ๋๋๋ก ํ๋ ๊ฒ
2. KL Divergence Loss
- KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()): KL ๋ฐ์ฐ์ latent variable์ ๋ถํฌ๊ฐ ์ ๊ท ๋ถํฌ์ ์ผ๋ง๋ ์ฐจ์ด๋๋์ง๋ฅผ ์ธก์
- ์์์ mu๋ ์ธ์ฝ๋์์ ๋์จ latent variable์ ํ๊ท , logvar๋ ๋ก๊ทธ ๋ถ์ฐ
- ์ด ๋ถ๋ถ์ VAE์ ํต์ฌ์ด๋ฉฐ, ๋ชจ๋ธ์ด latent variable๋ฅผ ํ์ค ์ ๊ท ๋ถํฌ์ ๊ฐ๊น๊ฒ ์ ๋ํ๋๋ก ํ๋ค.
- ์ด๋ ๋ชจ๋ธ์ด ํ๋ จ ์ค์ ๋ ์์ ์ ์ผ๋ก ํ์ต๋๋๋ก ํ๋ค.