๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
๐Ÿ› Research/Generative AI

VAE (Variational Autoencoder) Explained | VAE PyTorch Code Example

by ๋ญ…์ฆค 2024. 1. 6.
๋ฐ˜์‘ํ˜•

 

VAE (Variational Autoencoder) 

์ขŒ : Autoencoder, ์šฐ : VAE

 

 

The VAE (Variational Autoencoder) is a generative model: a neural-network architecture used mainly for dimensionality reduction and generation tasks. A VAE learns latent variables of the data and uses them to generate new data, and it is widely used in applications such as image and audio generation. A VAE consists of two main parts: an encoder and a decoder.

 

Autoencoder(์˜คํ† ์ธ์ฝ”๋”)์™€ ํ—ท๊ฐˆ๋ฆด ์ˆ˜ ์žˆ๋Š”๋ฐ,

์˜คํ† ์ธ์ฝ”๋”๋Š” ์ธํ’‹์„ ๋˜‘๊ฐ™์ด ๋ณต์›ํ•  ์ˆ˜ ์žˆ๋Š” latent variable z๋ฅผ ๋งŒ๋“œ๋Š” ๊ฒƒ์ด ๋ชฉ์ , ์ฆ‰ ์ธ์ฝ”๋”๋ฅผ ํ•™์Šตํ•˜๋Š” ๊ฒƒ์ด ์ฃผ ๋ชฉ์ ์ด๊ณ ,

VAE์˜ ๊ฒฝ์šฐ ์ธํ’‹ x๋ฅผ ์ž˜ ํ‘œํ˜„ํ•˜๋Š” latent vector๋ฅผ ์ถ”์ถœํ•˜๊ณ , ์ด๋ฅผ ํ†ตํ•ด ์ธํ’‹ x์™€ ์œ ์‚ฌํ•˜์ง€๋งŒ ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์ด ๋ชฉ์ ์ด๊ธฐ์— ์ดˆ์ ์ด ๋””์ฝ”๋”์— ๋งž์ถฐ์ ธ ์žˆ๋‹ค.

 

 

- ์ธ์ฝ”๋” (Encoder)

  • ์ธ์ฝ”๋”๋Š” ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ฃผ์–ด์ง„ ์ž ์žฌ ๊ณต๊ฐ„์˜ ํ™•๋ฅ  ๋ถ„ํฌ๋กœ ๋งคํ•‘ํ•˜๋Š” ์—ญํ• ์„ ์ˆ˜ํ–‰
  • ์ฃผ์–ด์ง„ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ x์— ๋Œ€ํ•ด, ์ธ์ฝ”๋”๋Š” ํ‰๊ท (μ)๊ณผ ํ‘œ์ค€ ํŽธ์ฐจ(σ)๋ฅผ ๊ฐ€์ง„ ์ •๊ทœ ๋ถ„ํฌ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ถœ๋ ฅ
  • ์ž…๋ ฅ ๋ฐ์ดํ„ฐ x๋ฅผ ์ž ์žฌ ๋ณ€์ˆ˜ z๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜๋กœ๋„ ๋ณผ ์ˆ˜ ์žˆ์Œ

 

- ์ž ์žฌ ๋ณ€์ˆ˜ (Latent Variable)

  • latent variable z๋Š” ๋ฐ์ดํ„ฐ๋ฅผ ํŠน์ •ํ•œ ํŠน์„ฑ์ด๋‚˜ ์š”์•ฝ๋œ ํ˜•ํƒœ๋กœ ํ‘œํ˜„ํ•˜๋Š”๋ฐ ์‚ฌ์šฉ๋˜๋Š” ๋ณ€์ˆ˜์ด๋‹ค.
  • ์ด ๋ณ€์ˆ˜๋Š” ์ธ์ฝ”๋”์— ์˜ํ•ด ํ•™์Šต๋˜๋ฉฐ, ๋ณดํ†ต ํ‰๊ท (μ)๊ณผ ํ‘œ์ค€ ํŽธ์ฐจ(σ)๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ •์˜๋œ๋‹ค.
  • latent variable์€ ํ‘œ์ค€ ์ •๊ทœ ๋ถ„ํฌ์—์„œ ์ƒ˜ํ”Œ๋ง๋œ๋‹ค.
  • ์˜คํ† ์ธ์ฝ”๋”๋Š” ์ธํ’‹๊ณผ ์•„์›ƒํ’‹์ด ํ•ญ์ƒ ๊ฐ™๋„๋ก ํ•˜๋Š” ๊ฒƒ์ด ๋ชฉ์ ์ด๋ผ๋ฉด, VAE๋Š” ๊ทธ๋ ‡์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ˜ํ”Œ๋งํ•˜์—ฌ ์ด๋ฅผ ํ†ตํ•ด ์ž ์žฌ ๋ณ€์ˆ˜๋ฅผ ๋งŒ๋“ ๋‹ค. 
  • ์˜คํ† ์ธ์ฝ”๋”์—์„œ ์ž ์žฌ ๋ณ€์ˆ˜๊ฐ€ ํ•˜๋‚˜์˜ ๊ฐ’์ด๋ผ๋ฉด, VAE์—์„œ์˜ ์ž ์žฌ ๋ณ€์ˆ˜๋Š” ๊ฐ€์šฐ์‹œ์•ˆ ํ™•๋ฅ  ๋ถ„ํฌ์— ๊ธฐ๋ฐ˜ํ•œ ํ™•๋ฅ ๊ฐ’

 

- ๋””์ฝ”๋” (Decoder)

  • ๋””์ฝ”๋”๋Š” ์ž ์žฌ ๋ณ€์ˆ˜ z๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„ ์›๋ž˜์˜ ๋ฐ์ดํ„ฐ x๋ฅผ ๋ณต์›ํ•˜๋Š” ์—ญํ• ์„ ์ˆ˜ํ–‰
  • ๋””์ฝ”๋”๋Š” ์ž ์žฌ ๋ณ€์ˆ˜๋กœ๋ถ€ํ„ฐ ์ƒ์„ฑ๋œ ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํฌ๋ฅผ ํ•™์Šต
  • ์ž ์žฌ ๋ณ€์ˆ˜๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„, ๋””์ฝ”๋”๋Š” ์ž ์žฌ ๋ณ€์ˆ˜์—์„œ ์›๋ž˜ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ถ„ํฌ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ถœ๋ ฅ

 

- ์žฌ๊ตฌ์„ฑ ์†์‹ค (Reconstruction Loss):

  • VAE์˜ ํ•™์Šต์€ ์žฌ๊ตฌ์„ฑ ์†์‹ค์„ ์ตœ์†Œํ™”ํ•˜๋„๋ก ์ด๋ฃจ์–ด์ง
  • ์žฌ๊ตฌ์„ฑ ์†์‹ค์€ ๋””์ฝ”๋”๊ฐ€ ์ž ์žฌ ๋ณ€์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์–ผ๋งˆ๋‚˜ ์ž˜ ์žฌ๊ตฌ์„ฑํ•˜๋Š”์ง€๋ฅผ ์ธก์ •
  • ์ด ์†์‹ค์€ ์ฃผ๋กœ ํ‰๊ท  ์ œ๊ณฑ ์˜ค์ฐจ(Mean Squared Error)๋‚˜ ๊ต์ฐจ ์—”ํŠธ๋กœํ”ผ ์†์‹ค(Cross-Entropy Loss)๋กœ ๊ณ„์‚ฐ ๋จ

 

- KL ๋ฐœ์‚ฐ (KL Divergence)

  • VAE์—์„œ๋Š” ์ž ์žฌ ๋ณ€์ˆ˜์˜ ๋ถ„ํฌ๊ฐ€ ํ‘œ์ค€ ์ •๊ทœ ๋ถ„ํฌ์™€ ์œ ์‚ฌํ•˜๋„๋ก ์œ ๋„ํ•˜๋Š” ์ถ”๊ฐ€์ ์ธ ํ•ญ์ธ KL divergence๊ฐ€ ์กด์žฌ
  • ์ด ํ•ญ์€ ๋ชจ๋ธ์ด ํ•™์Šต๋œ ๋ถ„ํฌ์™€ ์›๋ž˜์˜ ์ •๊ทœ ๋ถ„ํฌ ๊ฐ„์˜ ์ฐจ์ด๋ฅผ ์ธก์ •ํ•˜๊ณ , latent variable์ด ๊ณ ๋ฅด๊ฒŒ ๋ถ„ํฌํ•˜๋„๋ก ํ•œ๋‹ค.

# VAE ๋ชจ๋ธ pytorch ์ฝ”๋“œ ์˜ˆ์‹œ

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# VAE model definition
class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()
        self.latent_size = latent_size

        # Encoder definition
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, latent_size * 2)  # outputs mu and logvar together
        )

        # Decoder definition
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()  # Sigmoid keeps outputs in [0, 1] for image generation
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Compute mu and logvar with the encoder
        enc_output = self.encoder(x)
        mu, logvar = enc_output[:, :self.latent_size], enc_output[:, self.latent_size:]

        # Sample the latent variable with the reparameterization trick
        z = self.reparameterize(mu, logvar)

        # Generate the image from the latent variable with the decoder
        recon_x = self.decoder(z)

        return recon_x, mu, logvar
  • ์ฃผ๋ชฉํ•  ๋งŒํ•œ ๋ถ€๋ถ„์€ latent vector์˜ ๋ฆฌํŒŒ๋ผ๋ฏธํ„ฐํ™”(reparameterization) ๊ณผ์ •์œผ๋กœ ํ‰๊ท (mu)๊ณผ ๋กœ๊ทธ ๋ถ„์‚ฐ(logvar)์„ ์‚ฌ์šฉํ•˜์—ฌ latent vector๋ฅผ ์ƒ˜ํ”Œ๋ง
  • std = torch.exp(0.5 * logvar) : ๋กœ๊ทธ ๋ถ„์‚ฐ์œผ๋กœ ํ‘œ์ค€ ํŽธ์ฐจ๋ฅผ ๊ณ„์‚ฐ
  • eps = torch.randn_like(std) : ํ‘œ์ค€ ์ •๊ทœ ๋ถ„ํฌ์—์„œ ์ƒ˜ํ”Œ๋ง๋œ ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ์ƒ์„ฑ. torch.randn_like() ํ•จ์ˆ˜๋Š” ์ฃผ์–ด์ง„ ํ…์„œ์™€ ๊ฐ™์€ ํฌ๊ธฐ์˜ ํ‘œ์ค€ ์ •๊ทœ ๋ถ„ํฌ์—์„œ ์ƒ˜ํ”Œ๋ง๋œ ๋žœ๋ค๊ฐ’์„ ์ƒ์„ฑ ํ•จ
  • mu + eps * std : ์ตœ์ข…์ ์œผ๋กœ latent vector๋ฅผ ๊ณ„์‚ฐ. ์ด๋ ‡๊ฒŒ ๊ตฌํ•œ latent vector๋Š” ๋””์ฝ”๋”์˜ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉ
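We can sanity-check the trick empirically: samples should have mean ≈ mu and std ≈ exp(0.5 * logvar), and gradients should flow through mu because the randomness is isolated in eps (the values 2.0 and 10,000 samples are arbitrary choices for this demonstration):

```python
import torch

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

torch.manual_seed(0)
mu = torch.tensor(2.0)
logvar = torch.tensor(0.0)  # sigma = exp(0.5 * 0) = 1

# Draw many samples; their empirical mean/std should approach mu and sigma.
samples = torch.stack([reparameterize(mu, logvar) for _ in range(10000)])

# Gradients flow through mu: z = mu + eps * std, so dz/dmu = 1.
mu_p = torch.tensor(2.0, requires_grad=True)
z = reparameterize(mu_p, torch.tensor(0.0))
z.backward()
```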

 

 

# VAE loss computation

# Loss ๊ณ„์‚ฐ
BCE = criterion(recon_data, data)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

loss = BCE + KLD

VAE์˜ loss๋Š” Reconstruction loss์™€ KL divergence๋กœ ๊ตฌ์„ฑ

์ „์ฒด ์†์‹ค์€ ์žฌ๊ตฌ์„ฑ ์†์‹ค๊ณผ KL ๋ฐœ์‚ฐ์˜ ํ•ฉ์œผ๋กœ ๊ตฌ์„ฑ → loss = BCE + KLD. 

 

1. Reconstruction Loss

  • criterion = nn.BCELoss(reduction='sum'): here the binary cross-entropy loss function is used with 'sum' reduction (the losses over all pixels are summed)
  • BCE = criterion(recon_data, data): the reconstruction loss measures the binary cross-entropy between the generated image (recon_data) and the original image (data)
  • The goal is to make the generated image similar to the original image

 

2. KL Divergence Loss

  • KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()): the KL divergence measures how far the distribution of the latent variable is from the standard normal distribution
  • In the formula, mu is the mean of the latent variable produced by the encoder, and logvar is its log variance
  • This term is at the heart of the VAE: it pushes the latent variable toward the standard normal distribution.
  • It also makes training more stable.
๋ฐ˜์‘ํ˜•