์์ฑ ๋ชจ๋ธ์์ Diffusion์ ๊ณ ํด์๋ ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค์ด๋ด๋ ํต์ฌ ๊ธฐ์ ๋ก ์๋ฆฌ ์ก์์ง๋ง, DDPM์ฒ๋ผ ํฝ์ ๊ณต๊ฐ์์ ์ง์ ๋ ธ์ด์ฆ๋ฅผ ๋ค๋ฃจ๋ ๋ฐฉ์์๋ ์น๋ช ์ ์ธ ๋จ์ ์ด ์์๋ค. ๋ฐ๋ก ์ฐ์ฐ๊ณผ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ด๋ค.
[Gen AI] Diffusion Model๊ณผ DDPM ๊ฐ๋ ์ค๋ช
์์ฑ ๋ชจ๋ธ์์ Diffusion ๋ชจ๋ธ์ ๊ณ ํด์๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ํต์ฌ ๊ธฐ์ ๋ก ์ฃผ๋ชฉ๋ฐ๊ณ ์๋๋ฐ, ์ด ๋ชจ๋ธ์ ๋ ธ์ด์ฆ๋ฅผ ์ ์ ์ ๊ฑฐํด๊ฐ๋ฉฐ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ค๋ ๊ฐ๋ ์ผ๋ก, Stable Diffusion, DALL·E 2 ๋ฑ ๋ค์ํ
mvje.tistory.com
์๋ฅผ ๋ค์ด, 256×256 ํด์๋์ ์ด๋ฏธ์ง๋ฅผ ์ง์ ๋ํจ์ (ํฝ์ ๋จ์๋ก ๋ ธ์ด์ฆ๋ฅผ ๋ฃ๊ณ ์ ๊ฑฐ)ํ๋ ค๋ฉด, ์๋ฐฑ MB์ ๋ฌํ๋ feature๋ฅผ ๋ฐ๋ณต์ ์ผ๋ก ์ฒ๋ฆฌํด์ผ ํ๋ค. ๊ณ ํด์๋์ผ์๋ก ์ด ๋ถ๋ด์ ๊ธฐํ๊ธ์์ ์ผ๋ก ์ปค์ ธ, GPU ๋ฉ๋ชจ๋ฆฌ ํ๊ณ์ ๊ธ์ธ ๋๋ฌํ๋ค. ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ฑ์ฅํ ๊ฒ์ด ๋ฐ๋ก Latent Diffusion Models (LDM) ์ด๋ค.
๐กLDM training์ ํด๋ณด๊ณ ์ถ๋ค๋ฉด → LDM_MNIST
Vision-AI-Tutorials/Image_Generation/LDM_MNIST at main · ldj7672/Vision-AI-Tutorials
Computer Vision & AI๋ฅผ ์ฝ๊ฒ ๋ฐฐ์ฐ๊ณ ์ค์ตํ ์ ์๋ ์์ ๋ชจ์์ ๋๋ค. Contribute to ldj7672/Vision-AI-Tutorials development by creating an account on GitHub.
github.com
1. LDM์ด๋?
LDM(Latent Diffusion Model)์ 2022๋ CVPR์ ๋ฐํ๋ ๋ ผ๋ฌธ “High-Resolution Image Synthesis with Latent Diffusion Models”์์ ์ฒ์ ์ ์๋์๋ค. ์ด ๋ชจ๋ธ์ Stable Diffusion์ ๋ฐ๋ก ๊ทธ ์ํ์ด๋ฉฐ, ๊ธฐ์กด DDPM ๋ฐฉ์์ ๊ฐ์ฅ ํฐ ๋ณ๋ชฉ์ธ ํฝ์ ๊ณต๊ฐ์์์ ๋ํจ์ ์ latent ๊ณต๊ฐ์ผ๋ก ์ฎ๊ธฐ๋ ์ ๋ต์ ์ทจํ๋ค.
์ฆ, LDM์ ๋จผ์ ๊ณ ํด์๋ ์ด๋ฏธ์ง๋ฅผ VAE(Variational Autoencoder)๋ฅผ ํตํด ํจ์ฌ ์์ latent ๊ณต๊ฐ์ผ๋ก ์ธ์ฝ๋ฉํ๋ค. ๊ทธ ํ ์ด latent ๊ณต๊ฐ์์ DDPM๊ณผ ๋์ผํ ๋ฐฉ์์ผ๋ก ๋ ธ์ด์ฆ๋ฅผ ์ฃผ์ ํ๊ณ ์ ๊ฑฐํ๋ ๊ณผ์ ์ ํ์ตํ๋ค. ๋ง์ง๋ง์ ์ด latent๋ฅผ ๋ค์ ๋์ฝ๋ฉํด ์๋์ ์ด๋ฏธ์ง๋ก ๋ณต์ํ๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๋ํจ์ ์ด ์ฒ๋ฆฌํด์ผ ํ feature map ํฌ๊ธฐ๊ฐ ์์ญ ๋ฐฐ ์์์ ธ, ํจ์ฌ ๋น ๋ฅด๊ณ ์ ์ ์์์ผ๋ก ๊ณ ํด์๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์๊ฒ ๋๋ค.
์๋ฅผ ๋ค์ด, 256×256×3 ํฌ๊ธฐ์ ์ด๋ฏธ์ง๋ฅผ VAE๋ก ์ธ์ฝ๋ฉํ๋ฉด 32×32×4 ์ ๋์ latent๋ก ๋ณํ๋๋ค. ์ฌ๊ธฐ์ ๋ํจ์ ์ ์ํํ๋ฉด, ์๋๋ณด๋ค ์ฐ์ฐ๋์ด ์ฝ 60๋ฐฐ ๊ฐ๊น์ด ์ค์ด๋ ๋ค. ๊ทธ ๊ฒฐ๊ณผ ์ผ๋ฐ์ ์ธ ์๋น์ GPU๋ก๋ Stable Diffusion ๊ฐ์ ๊ณ ํด์๋ ์ด๋ฏธ์ง ์์ฑ์ด ๊ฐ๋ฅํด์ก๋ค.
์ง์ง ๊ฐ๋จํ ๋งํ๋ฉด,
LDM์ <์ด๋ฏธ์ง>๋ฅผ VAE Encoder์ ํต๊ณผ์์ผ <latent vector>๋ก ๋ณํํ๊ณ , ์ด latent vector์ ๋
ธ์ด์ฆ๋ฅผ ์ฃผ์
ํ ๋ค,
Unet์ด ๊ทธ ๋
ธ์ด์ฆ๋ฅผ ์์ธกํ๋๋ก ํ์ตํ๋ ๋ชจ๋ธ์ด๋ค. ๊ทธ๋์ ์ค์ LDM์ VAE + Unet์ผ๋ก ๊ตฌ์ฑ๋๋ฉฐ, ์
๋ ฅ ๋ฐ์ดํฐ๋ ์ด๋ฏธ์ง์ด๊ณ , ์กฐ๊ฑด ์ ๋ณด๋ก๋ ํด๋์ค๋ ํ
์คํธ ์๋ฒ ๋ฉ ๋ฑ์ด ํจ๊ป ํ์ฉ๋ ์ ์๋ค.
2. LDM์ ๊ตฌ์กฐ์ ํ์ต ๊ณผ์
LDM์ ํฌ๊ฒ ์ธ ๊ฐ์ง ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ง๋ค.
2.1 Encoding: ์ด๋ฏธ์ง → latent
์๋ณธ ์ด๋ฏธ์ง xโ๋ฅผ VAE Encoder๋ฅผ ํตํด latent zโ๋ก ์์ถํ๋ค.
xโ → Encoder → zโ
์ด latent zโ๋ ์ด๋ฏธ์ง์ ๊ตฌ์กฐ์ , ์๊ฐ์ ์ ๋ณด๋ฅผ ์ ์งํ์ง๋ง, ํฝ์ ๋จ์์ ๋ ธ์ด์ฆ์๋ ๋ ๋ฏผ๊ฐํ compact representation์ด๋ค. VAE ํ์ต์ diffusion ํ์ต ์ด์ ์ ๋จผ์ ์๋ฃ๋๋ฉฐ, ์ดํ์๋ Encoder์ Decoder๋ฅผ ๊ณ ์ ์์ผ ์ฌ์ฉํ๋ค.
2.2 Latent Diffusion: latent์์ noise ์ฃผ์ ๊ณผ ์์ธก
์ด์ ๊ธฐ์กด DDPM์์ ํฝ์ ๊ณต๊ฐ์์ ํ๋ ๊ฒ์ latent ๊ณต๊ฐ์์ ์ํํ๋ค.
Forward Process
zโ → zโ → zโ → ... → z_T
zโ = sqrt(αฬ_t) * zโ + sqrt(1-αฬ_t) * ε
- zโ์ ์๊ฐ step t์ ๋ฐ๋ผ ์ ์ ๋ ธ์ด์ฆ๋ฅผ ๋ํด zโ๋ฅผ ๋ง๋ ๋ค.
- ํ์ต์์๋ xโ์์ ๋ฐ๋ก zโ๋ฅผ ์ป์ ๋ค, random noise ε๊ณผ timestep t๋ฅผ ์ํ๋งํด์์ผ๋ก ํ ๋ฒ์ ๋ง๋ ๋ค.
- ์ด๋ DDPM์์์ ๋ฐฉ์๊ณผ ์์ ํ ๋์ผํ๋ค.
์๋๋ time embedding๊ณผ class embedding์ Unet์ ์ง์ ์ฃผ์ ํ๋ ํํ์ ๋จ์ํ๋ PyTorch ์์ ์ฝ๋์ด๋ค.
class LDMUNet(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
# Time embedding (128 → 512)
self.time_mlp = nn.Sequential(
nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 512)
)
# Class embedding (256 → 512)
self.class_emb = nn.Embedding(num_classes + 1, 256)
self.class_mlp = nn.Sequential(
nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 512)
)
# Input projection
self.input_proj = nn.Conv2d(4, 64, 3, padding=1) # [4 → 64]
# Encoder
self.enc1 = nn.Conv2d(64 + 512, 128, 4, stride=2, padding=1) # 32→16
self.enc2 = nn.Conv2d(128 + 512, 256, 4, stride=2, padding=1) # 16→8
self.enc3 = nn.Conv2d(256 + 512, 512, 4, stride=2, padding=1) # 8→4
# Middle
self.middle1 = nn.Conv2d(512 + 512, 768, 3, padding=1)
self.middle2 = nn.Conv2d(768 + 512, 512, 3, padding=1)
# Decoder
self.dec3 = nn.ConvTranspose2d(1024 + 512, 256, 4, stride=2, padding=1) # 4→8
self.dec2 = nn.ConvTranspose2d(512 + 512, 128, 4, stride=2, padding=1) # 8→16
self.dec1 = nn.ConvTranspose2d(192 + 512, 64, 4, stride=2, padding=1) # 16→32
# Output
self.output_proj = nn.Conv2d(64 + 512, 4, 3, padding=1) # back to [4]
def forward(self, x, t, class_labels=None):
"""
- x: [batch, 4, 32, 32] (latent z_t)
- t: [batch] timestep
- class_labels: [batch] class index
"""
# Create condition embedding
t_emb = self.time_mlp(t) # [batch, 512]
if class_labels is None:
class_labels = torch.full(
(x.size(0),), self.class_emb.num_embeddings - 1,
device=x.device, dtype=torch.long
)
c_emb = self.class_mlp(self.class_emb(class_labels)) # [batch, 512]
cond_emb = (t_emb + c_emb).unsqueeze(-1).unsqueeze(-1) # [batch, 512, 1, 1]
# Input projection
x = self.input_proj(x) # [batch, 64, 32, 32]
# Encoder with condition
x = torch.cat([x, cond_emb.expand(-1, -1, 32, 32)], dim=1)
x1 = F.relu(self.enc1(x)) # [batch, 128, 16, 16]
x1_cat = torch.cat([x1, cond_emb.expand(-1, -1, 16, 16)], dim=1)
x2 = F.relu(self.enc2(x1_cat)) # [batch, 256, 8, 8]
x2_cat = torch.cat([x2, cond_emb.expand(-1, -1, 8, 8)], dim=1)
x3 = F.relu(self.enc3(x2_cat)) # [batch, 512, 4, 4]
# Middle with condition
x3_cat = torch.cat([x3, cond_emb.expand(-1, -1, 4, 4)], dim=1)
x_mid = F.relu(self.middle1(x3_cat)) # [batch, 768, 4, 4]
x_mid = torch.cat([x_mid, cond_emb.expand(-1, -1, 4, 4)], dim=1)
x_mid = F.relu(self.middle2(x_mid)) # [batch, 512, 4, 4]
# Decoder with condition
x_mid_cat = torch.cat([x_mid, x3, cond_emb.expand(-1, -1, 4, 4)], dim=1)
x = F.relu(self.dec3(x_mid_cat)) # [batch, 256, 8, 8]
x = torch.cat([x, x2, cond_emb.expand(-1, -1, 8, 8)], dim=1)
x = F.relu(self.dec2(x)) # [batch, 128, 16, 16]
x = torch.cat([x, x1, cond_emb.expand(-1, -1, 16, 16)], dim=1)
x = F.relu(self.dec1(x)) # [batch, 64, 32, 32]
# Output projection with final condition
x = torch.cat([x, cond_emb.expand(-1, -1, 32, 32)], dim=1)
return self.output_proj(x) # [batch, 4, 32, 32]
- ์ Unet ์์ ์ฝ๋๋ timestep embedding + class embedding์ ๋ํด ๋ง๋ cond_emb๋ฅผ encoder์ decoder์ ์ฃผ์ ํด, ๊ฐ ๋จ๊ณ์์ diffusion ๊ณผ์ ์ ์๊ฐ๊ณผ ์กฐ๊ฑด์ ๋ฐ์ํ๋ค.
- cond_emb๋ ์๋ 1D vector์ด์ง๋ง, ์ธ์ฝ๋/๋์ฝ๋์์ spatial feature map๊ณผ concatํ๊ธฐ ์ํด [batch, 512, H, W]๋ก broadcastํด์ ์ฌ์ฉํ๋ค.
- ์ด๋ ๊ฒ ๊ฐ๋จํ๋ ๊ตฌ์กฐ๋ก๋ LDM์ "time + class guided Unet" ๊ฐ๋ ์ ์ง๊ด์ ์ผ๋ก ์คํํ ์ ์๋ค.
Reverse Process
noise_pred = UNet(zโ, t)
L = E[||ε - noise_pred||²]
- ํ์ต ๋์์ธ UNet์ ์ด zโ์ t๋ฅผ ์ ๋ ฅ๋ฐ์, zโ์ ์์ฌ ์๋ ๋ ธ์ด์ฆ ε์ ์์ธกํ๋ค.
- ๊ทธ๋ฆฌ๊ณ ์ค์ noise ε๊ณผ MSE Loss๋ฅผ ํตํด ์ฐจ์ด๋ฅผ ์ค์ฌ๋๊ฐ๋ค.
- ์ด๋ฐ ๋ฐ๋ณต์ ํตํด, ๋ค์ํ t์์ ๋ ธ์ด์ฆ๊ฐ ํฌํจ๋ zโ๋ฅผ ๋ณด๊ณ ๊ทธ ์์ ์ด๋ค ๋ ธ์ด์ฆ๊ฐ ๋ค์ด์๋์ง ์ ์์ธกํ ์ ์๋๋ก ํ์ต๋๋ค.
2.3 Decoding: latent → ์ด๋ฏธ์ง ๋ณต์
์์ฑ์ด ๋๋ latent zโ๋ ๊ณ ์ ๋ VAE Decoder๋ฅผ ํตํด ๋ค์ ๊ณ ํด์๋ ์ด๋ฏธ์ง๋ก ๋ณต์๋๋ค.
zโ → Decoder → xฬโ
์ด ๊ณผ์ ์ ํตํด latent์์ ์ ๋ง๋ค์ด์ง ๊ตฌ์กฐ๊ฐ ๋์ฝ๋ฉ๋์ด ํฝ์ ๊ณต๊ฐ์ ์์ฐ์ค๋ฌ์ด ์ด๋ฏธ์ง๋ก ๋ฐ๋๋ค.
์ ๋ฆฌํ๋ฉด...
โ LDM ํ์ต ํ๋ก์ฐ
[์ด๋ฏธ์ง xโ]
|
v
[VAE Encoder]
|
v
[latent zโ]
|
v
[Noise injection]
zโ = sqrt(αฬ_t) * zโ + sqrt(1-αฬ_t) * ε
|
v
+-------------------+
| Unet(zโ, t, cond)|
| → noise_pred |
+-------------------+
|
v
[MSE Loss]
= || noise_pred - ε ||²
|
v
[Backpropagation]
3. Inference & DDIM Sampling
์์ฑ(Inference) ๋จ๊ณ์์๋ ์์ ํ ๋๋คํ latent noise z_T์์ ์์ํ๋ค. ๊ทธ๋ฆฌ๊ณ ํ์ต๋ UNet์ ์ด์ฉํด ์ ์ง์ ์ผ๋ก ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ฉฐ z_{T-1}, z_{T-2}, ..., zโ์ผ๋ก ์ด๋ํ๋ค.
์ด๋ DDPM ๋ฐฉ์์ ์๋ฐฑ~์์ฒ ์คํ ์ ๊ฑฐ์น๋ฉฐ ์กฐ๊ธ์ฉ ๋ ธ์ด์ฆ๋ฅผ ์ค์ด๋ stochastic sampling์ ํ๋ค. ๋ฐ๋ฉด LDM์ ์ผ๋ฐ์ ์ผ๋ก DDIM(Denoising Diffusion Implicit Models) ๋ฐฉ์์ ์ฌ์ฉํด, ๊ฐ์ β ์ค์ผ์ค์ ์ ์ ์คํ ์ผ๋ก deterministicํ๊ฒ ๋ด๋ค. DDIM์ ODE(ํ๋ฅ ์ ๋ฏธ๋ถ ๋ฐฉ์ ์ → ๊ฒฐ์ ์ ๋ฏธ๋ถ ๋ฐฉ์ ์) ํด์์ ์ด์ฉํด noise sampling term์ ์ ๊ฑฐํ๊ฑฐ๋ ์กฐ์ ํจ์ผ๋ก์จ,
- ๋ ์ ์ step์์๋ ๊นจ๋ํ ๊ฒฐ๊ณผ๋ฅผ ๋ง๋ค ์ ์๊ณ ,
- ๋์ผํ z_T์์ ํญ์ ๊ฐ์ xฬโ๋ฅผ ์์ฑํ ์๋ ์๋ค.
์ฆ, DDIM์ ํ ์คํ ๋น ๋ ๋ง์ ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ฉฐ ์ด๋ํ๊ณ , η ํ๋ผ๋ฏธํฐ๋ฅผ ํตํด sampling stochasticity๋ฅผ ์กฐ์ ํ ์๋ ์๋ค. ์ด ๋๋ถ์ ๋ณดํต 50~100 ์คํ ์ ๋๋ง์ผ๋ก๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ธ๋ค.
์ ๋ฆฌํ๋ฉด...
โ LDM ํ์ต ํ๋ก์ฐ
[Random latent z_T ~ N(0,I)]
|
v
for t = T...1:
+-------------------+
| Unet(zโ, t, cond)|
| → noise_pred |
+-------------------+
|
v
[Denoise step]
z_{t-1} = f(zโ, noise_pred, t)
(๋ฐ๋ณต)
|
v
[latent zโ]
|
v
[VAE Decoder]
|
v
[์ด๋ฏธ์ง xฬโ (์ํ)]
4. Text-to-Image (T2I)์ LDM
Stable Diffusion์ ์ด LDM ๊ตฌ์กฐ๋ฅผ ๊ทธ๋๋ก ๊ฐ์ ธ์, UNet์ text embedding์ Cross-Attention์ผ๋ก ์ฐ๊ฒฐํด ์กฐ๊ฑด๋ถ ์์ฑ์ ๊ฐ๋ฅํ๊ฒ ํ๋ค. ์ฆ, T2I๋ UNet์ ์ ๋ ฅ์ผ๋ก
(zโ, t, text_embedding)
์ ์ฃผ์ด,
- zโ์ t๋ฅผ ๋ณด๊ณ noise๋ฅผ ์์ธกํ๋,
- attention query-key-value์ text embedding์ ๋ฃ์ด prompt์ ๋ง๊ฒ ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค์ด๊ฐ๋ค.
๊ฒฐ๊ณผ์ ์ผ๋ก “astronaut riding a horse” ๊ฐ์ ๋ฌธ์ฅ์ ์ ๋ ฅํ๋ฉด, ์ด ์กฐ๊ฑด์ ๋ง๊ฒ latent ๊ณต๊ฐ์์ ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํด๋๊ฐ๋ฉด์ ์ํ๋ ์ด๋ฏธ์ง๊ฐ ํ์ํ๋ค.
LDM์ “ํฝ์
๊ณต๊ฐ ๋์ latent ๊ณต๊ฐ์์ diffusion์ ์ํ”ํ์ฌ ๋ฉ๋ชจ๋ฆฌ์ ์๋๋ฅผ ํ๊ธฐ์ ์ผ๋ก ๊ฐ์ ํ DDPM์ ํ์ฅํ์ด๋ฉฐ,
์ด๋ฅผ ํตํด Stable Diffusion ๊ฐ์ ๊ณ ํด์๋ Text-to-Image ์์ฑ์ด ๊ฐ๋ฅํด์ก๋ค. ์ด๋ฌํ ๋ฐฉ์์ ์ดํ ControlNet, Inpainting, 3D NeRF reconstruction ๋ฑ ๋ค์ํ ๋ํจ์ ๊ธฐ๋ฐ ๊ธฐ์ ์ ํ์ค์ด ๋์์ผ๋ฉฐ, ์ฌ์ ํ ๋ฉํฐ๋ชจ๋ฌ ์์ฑ(ํ
์คํธ-์ด๋ฏธ์ง-์ค๋์ค) ๋ถ์ผ์์ ํ๋ฐํ ์ฐ๊ตฌ๋๊ณ ์๋ค.
'๐ Research > Generative AI' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Gen AI] Diffusion Transformer (DiT) ์๋ฒฝ ์ดํดํ๊ธฐ! (0) | 2025.07.15 |
---|---|
[Gen AI] Diffusion ๋ชจ๋ธ ์ํ๋ง & ํ์ต ํธ๋ฆญ ์ ๋ฆฌ (4) | 2025.07.08 |
[Gen AI] Diffusion Model๊ณผ DDPM ๊ฐ๋ ์ค๋ช (0) | 2025.03.31 |
[๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] DREAMFUSION: TEXT-TO-3D USING 2D DIFFUSION (0) | 2025.03.23 |
[๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Zero-1-to-3: Zero-shot One Image to 3D Object | Single-view object reconstruction (0) | 2025.03.22 |