1. Intro

์ต๊ทผ Back to Basics: Let Denoising Generative Models Denoise (JiT) ๋ ผ๋ฌธ์ด Diffusion ๋ถ์ผ์์ ๊ฝค ํซํ ์ฐ๊ตฌ์ด๋ค. ํต์ฌ์ ๋งค์ฐ ๋จ์ํ๋ฐ, "Diffusion ๋ชจ๋ธ์ ๋ณธ๋ ๊นจ๋ํ ์ด๋ฏธ์ง๋ฅผ ๋ณต์ํ๋ ๋ชจ๋ธ์ธ๋ฐ, ์ ๋๋ถ๋ถ์ ๊ตฌํ์ ๋ ธ์ด์ฆ(ฯต)๋ v(velocity)๋ง ์์ธกํ ๊น?" JiT๋ ๋ฐ๋ก ์ด ์ง๋ฌธ์์ ์ถ๋ฐํด, "๊ทธ๋ฅ ํด๋ฆฐ ์ด๋ฏธ์ง(x)๋ฅผ ์ง์ ์์ธกํ๋ฉด ๋ ์ ๋๋ค"๋ผ๋ ๋งค์ฐ ์ง๊ด์ ์ด์ง๋ง ๊ฐ๋ ฅํ ๊ฒฐ๋ก ์ ์ ์ํ๋ค. ํนํ ๊ณ ํด์๋ ํฝ์ ๊ณต๊ฐ์์๋ ์ด ํจ๊ณผ๊ฐ ๊ทน์ ์ผ๋ก ๋ํ๋๋ค.
1.1 ๋ฌธ์ ์์: ์ x-prediction์ธ๊ฐ?
๊ธฐ์กด diffusion ๋ชจ๋ธ์ ํฌ๊ฒ ฯต-prediction ๋๋ v-prediction์ ์ฌ์ฉํ๋ค. ๊ทธ๋ฌ๋ ์ด ๋ ๋์์ ๋ ธ์ด์ฆ๊ฐ ํฌ๊ฒ ํฌํจ๋ ๊ณ ์ฐจ์ ๋ฐ์ดํฐ(latent)์ด๋ฉฐ, ๋ชจ๋ธ์ด ์ด๋ฅผ ์ง์ ์์ธกํ๋ ๊ณผ์ ์์ ๋์ capacity๋ฅผ ์๊ตฌํ๋ค.
๋ฐ๋ฉด, ์์ฐ ์ด๋ฏธ์ง x๋ ๋ณธ์ง์ ์ผ๋ก ์ ์ฐจ์ manifold ์์ ์กด์ฌํ๋ค(๋ ผ๋ฌธ Fig. 1). ์ฆ, ๋ชจ๋ธ์ด ์์ธกํด์ผ ํ๋ ์ ๋ณด๋์ด ํจ์ฌ ์ ๋ค. ๋ฐ๋ผ์ ๋ชจ๋ธ capacity๊ฐ ์ถฉ๋ถํ์ง ์์ ์ํฉ์์๋ ์คํ๋ ค x๋ฅผ ์ง์ ์์ธกํ๋ ๊ฒ์ด ํจ์ฌ ์์ ์ ์ด๋ผ๋ ์ ์ ์ด ๋ ผ๋ฌธ์ ๋งค์ฐ ์ค๋๋ ฅ ์๊ฒ ๋ณด์ฌ์ค๋ค.
์๋ฅผ ๋ค์ด, 512×512 ์ด๋ฏธ์ง์ 32×32 ํจ์น๋ 3,072์ฐจ์์ ์ด๋ฅด๋ฉฐ, ์ด๋ฅผ ๊ทธ๋๋ก ๋ชจ๋ธ์ด ๋ค๋ฃจ๋ ๊ฒ์ ๋งค์ฐ ์ด๋ ต๋ค. ๊ทธ๋ฌ๋ x๋ ๋ ธ์ด์ฆ๋ณด๋ค ๊ตฌ์กฐ๊ฐ ๋ช ํํ๊ณ manifold ๊ตฌ์กฐ๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ์ด ์ด๋ฅผ ํ์ตํ๋ ๋ฐ ํจ์ฌ ์ ๋ฆฌํ๋ค.
1.2 JiT๊ฐ ์ ์ํ๋ ํต์ฌ ์ฒ ํ
JiT(JUST image Transformers)์ ์ฒ ํ์ ๋ช ํํ๋ค.
- ViT๋ฅผ ๊ทธ๋๋ก ์ฌ์ฉํ๋ค
- ๋ณ๋์ tokenizer, VAE, perceptual loss ํ์ ์์
- latent space๋ ์ฌ์ฉํ์ง ์๊ณ ์ค์ง ํฝ์ ๊ณต๊ฐ์์ diffusion ์ํ
- diffusion์ prediction target์ x๋ก ๊ณ ์ ํ๋ค
์ฆ, "์๋ ๊ทธ๋๋ก์ Transformer + ์๋ ๊ทธ๋๋ก์ ์ด๋ฏธ์ง" ์กฐํฉ๋ง์ผ๋ก high-resolution diffusion์ ์ฑ๊ณต์ ์ผ๋ก ์ํํ ์ ์๋ค๋ ์ ์ ์ค์ฆํ๋ค.
2. Diffusion Prediction Space ๋ถ์

๋ ผ๋ฌธ์์๋ x, ฯต, v ์ธ ๊ฐ์ง๋ฅผ prediction target์ผ๋ก ๋ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ loss space์ ์กฐํฉํ๋ฉด ์ด 9๊ฐ์ง ๊ฒฝ์ฐ๊ฐ ๋๋ค๊ณ ์ ๋ฆฌํ๋ค. (Table 1) ์ธ ๊ฒฝ์ฐ๋ ์๋์ ๊ฐ์ ์ฑ๊ฒฉ์ ๊ฐ์ง๋ค.
2.1 x-prediction
- ๋ชจ๋ธ ์ถ๋ ฅ์ด ์ง์ ํด๋ฆฐ ์ด๋ฏธ์ง ๋ณต์
- manifold ์์ ๊ตฌ์กฐ์ ์ธ ๋ฐ์ดํฐ๋ฅผ ์์ธก → ํ์ต ์ฉ์ด
- 256×256 ์ด์ ๊ณ ํด์๋์์๋ ์์ ์ ์ผ๋ก ๋์
- ํนํ high-dim pixel space์์ ๋ชจ๋ธ capacity ์๊ตฌ๊ฐ ๊ฐ์ฅ ๋ฎ์
2.2 ฯต-prediction
- ๋ชจ๋ธ์ด clean image๋ฅผ ์์ธกํ์ง ์๊ณ ๋ ธ์ด์ฆ ฯต๋ฅผ ์ง์ ์์ธกํ๋ ๋ฐฉ์
- ๋ ธ์ด์ฆ๋ฅผ ์์ธกํด์ผ ํ๋ฏ๋ก ๊ณ ์ฐจ์ ๊ณต๊ฐ ์ ์ฒด๋ฅผ modeling ํ์
- latent space์์ ์ข์ง๋, ๊ณ ์ฐจ์ pixel-space์์๋ ๋ชจ๋ธ capacity๊ฐ ๋ถ์กฑํ๋ฉด catastrophic failure ๋ฐ์
- ์คํ์์ ์ค์ ๋ก FID 300 ์ด์์ผ๋ก ๋ถ๊ดด
- DDPM, DDIM, Stable Diffusion(LDM), DiT ๋ฑ ๋๋ถ๋ถ
2.3 v-prediction
- Flow matching / Rectified Flow ๋ชจ๋ธ
- ์ฌ์ ํ x๋ณด๋ค ๊ณ ์ฐจ์ ์ ๋ณด + ๋ ธ์ด์ฆ๊ฐ ์์ธ off-manifold ๊ฐ์ด๋ฏ๋ก pixel space์์๋ ฯต-prediction์ฒ๋ผ collapse ์ํ์ด ์กด์ฌ
ํต์ฌ ๊ด์ฐฐ

- 256×256 pixel-space์์ ฯต/v prediction์ ์์ ํ ๋ถ๊ดดํ์ง๋ง, x-prediction์ ์ ์ ์๋ํ๋ค. (Table 2(a))
- ๋ฐ๋ฉด 64×64 ๊ฐ์ ์ ํด์๋์์๋ capacity issue๊ฐ ๋ํด 9๊ฐ ์กฐํฉ ๋ชจ๋ ์ค์ํ ์ฑ๋ฅ์ ๋ณด์(Table 2(b))
์ฆ, “์ ์ง๊ธ๊น์ง pixel diffusion ๋ชจ๋ธ์ latent space์ ์์กดํ๋๊ฐ?”์ ๋ํ ๋ต์ ๋ช ํํด์ง๋ค. ฯต/v๋ฅผ ์ง์ ์์ธกํ๋ ๊ฒ์ ๊ณ ์ฐจ์ ๊ณต๊ฐ์์ ๋๋ฌด ํ๋ค๊ธฐ ๋๋ฌธ์ด๊ณ , JiT๋ ๊ทธ ๋ฐฉํฅ์ ๋ฐ๊ฟ x๋ฅผ ์ง์ ์์ธกํ๋ฉด ์ด ๋ฌธ์ ๊ฐ ์์ด์ง๋ค๋ ์ ์ ์คํ์ผ๋ก ์ฆ๋ช ํ ์ ์ด๋ค.
3. JiT Architecture

JiT(JuST image Transformer)๋ ์ด๋ฆ ๊ทธ๋๋ก ์ ์คํธ ์ด๋ฏธ์ง Transformer์ด๋ค. ์ง๊ธ๊น์ง์ Diffusion ์์คํ ์ด ๊ฐ์ง ๋ณต์กํ ๊ตฌ์ฑ์์๋ค(์: VAE, latent tokenizer, perceptual loss, multi-scale U-Net ๋ฑ)์ ๊ณผ๊ฐํ ์ ๊ฑฐํ๊ณ , ์๋ณธ ์ด๋ฏธ์ง(pixels)๋ง Transformer๋ก ์ง์ ์ฒ๋ฆฌํ๋ ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
์ ๊ธฐํ๊ฒ๋(?) ์ด ๋จ์ํ ๊ตฌ์กฐ๊ฐ ๊ณ ํด์๋ ์ด๋ฏธ์ง ์์ฑ์์ collapse ์์ด ์์ ์ ์ผ๋ก ์๋ํ๋ค.
3.1 Patchify → ViT → Patch Reconstruction
JiT๋ ์ด๋ฏธ์ง๋ฅผ Vision Transformer(ViT)์ฒ๋ผ ๊ณ ์ ํฌ๊ธฐ ํจ์น(tokens)๋ก ๋๋์ด ์ ๋ ฅํ๋ค. ์ด๋ฏธ์ง ํฌ๊ธฐ๊ฐ 512×512๋ผ๊ณ ๊ฐ์ ํ๋ฉด, ํจ์น ์ฌ์ด์ฆ์ ๋ฐ๋ผ ์ ๋ ฅ ํ ํฐ์ ํํ๋ ์๋์ ๊ฐ๋ค.
| Patch Size | Patch Dim | Token ๊ฐ์ | ํน์ง |
| 16 x 16 | 768 | 1024 | ํ ํฐ ๊ฐ์๊ฐ ๋ง์ ๋๋ฆผ |
| 32 x 32 | 3072 | 256 | compute ํจ์จ + ํจ์น ๋จ์ ํํ๋ ฅ ๊ท ํ |
| 64 x 64 | 12288 | 64 | ํ ํฐ ๊ฐ์ ์ ์ด์ ๋น ๋ฅด์ง๋ง, ํ ํฐ ์ฐจ์์ด ๋งค์ฐ ํผ |
JiT์ ๊ธฐ๋ณธ ์ค์ ์ p=32๋ก, ๊ฐ patch๋ 3072์ฐจ์(=32×32×3)์ด๋ผ๋ ๋งค์ฐ ํฐ ๋ฒกํฐ๋ค.
์ค์ํ ํฌ์ธํธ๋ ์ผ๋ฐ Diffusion Transformer๋ latent(4~8 channels)๋ CNN feature map์ ํ ํฐ์ผ๋ก ์ฐ์ง๋ง, JiT๋ ์์ ์์ pixel ํจ์น๋ฅผ ํ ํฐ์ผ๋ก ์ฌ์ฉํ๋ค. ์ฆ, ๋ ์ด์ tokenizer๊ฐ ํ์ ์๊ณ ์ด๋ฏธ์ง ๊ทธ ์์ฒด๊ฐ ๋ชจ๋ธ์ ์ ๋ ฅ ํ ํฐ์ด๋ค. ์ด๊ฒ ๊ฐ๋ฅํ ์ด์ ๊ฐ ๋ฐ๋ก x-prediction ๊ธฐ๋ฐ์ ์์ ์ฑ ๋๋ถ์ด๋ผ๊ณ ํ๋ค.
3.2 Bottleneck embedding

3072์ฐจ์์ ํจ์น๋ฅผ ๊ทธ๋๋ก Transformer์ ๋ฃ์ผ๋ฉด ๋ฉ๋ชจ๋ฆฌ์ compute ๋น์ฉ์ด ๋งค์ฐ ํฌ๋ค. ๊ทธ๋์ JiT๋ Patch Embedding ๋จ๊ณ์์ “๋ณ๋ชฉ(bottleneck)” ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๋ค.
3072 (raw patch)
→ 32 (bottleneck)
→ 768 (transformer hidden dim)
๋๋๊ฒ๋ bottleneck ํฌ๊ธฐ๋ฅผ ๊ทน๋จ์ ์ผ๋ก ์ค์ฌ๋ ์ฑ๋ฅ์ด ๊ฑฐ์ ๋จ์ด์ง์ง ์๋๋ค. (๋ ผ๋ฌธ Figure 4: d′=32๋ก ์ค์ฌ๋ ImageNet FID ๊ฐ์ )
๋ ผ๋ฌธ์์ ๊ฐ์กฐํ๋ฏ clean image x๋ ์๋ low-dimensional manifold ์์ ์๊ธฐ ๋๋ฌธ์ raw pixel ์ ๋ณด ์ ์ฒด๋ฅผ ๋ณด์กดํ ํ์๊ฐ ์๋ค. ์ฆ, patch์ ๋ชจ๋ ๋ํ ์ผ์ ์ ์งํ ํ์๊ฐ ์๊ณ manifold ๊ตฌ์กฐ๋ง ์ ์ถ์ถํ๋ฉด Transformer๊ฐ ์์ ์ ์ผ๋ก ๋ณต์ํ ์ ์๋ค.
์ด ์คํ์ “pixel diffusion์ patch dimension ๋๋ฌธ์ ๋ถ๊ฐ๋ฅํ๋ค”๋ ๊ธฐ์กด ์๊ฐ์ด ํ๋ ธ์์ ๋ณด์ฌ์ค๋ค.
3.3 Transformer Backbone — Plain ViT, But Diffusionized
Patch Embedding ์ดํ์๋ ๊ฑฐ์ ๊ทธ๋๋ก์ ViT๊ฐ ์ฌ์ฉ๋๋ค.
- SwiGLU FFN
- RMSNorm
- qk-Norm
- Rotary Positional Embedding (RoPE)
- AdaLN-zero ํด๋์ค conditioning
์ฆ, ๊ณ ๋ํ๋ U-Net ๊ตฌ์กฐ๋ latent ํนํ ๋ชจ๋์ด ์๋๋ผ, ์ฌ์ค์ ์ผ๋ฐ ์ธ์ด ๋ชจ๋ธ/๋น์ ๋ชจ๋ธ๊ณผ ๋์ผํ Transformer๋ก diffusion์ ์ํํ๋ค. ์ด๋ ๋ชจ๋ธ ๊ตฌ์กฐ๊ฐ ํน์ ๋๋ฉ์ธ์ ์ข ์๋์ง ์๋๋ค๋ ์๋ฏธ์ด๊ธฐ๋ ํ๋ค.
3.4 Output: ๋ชจ๋ธ์ด ์์ธกํ๋ ๊ฒ์ ํญ์ Clean Image Patch(x_pred)
์ฌ๊ธฐ์ JiT์ ํต์ฌ์ด ๋๋ฌ๋๋๋ฐ, Transformer๋ ๋งค ์คํ ๋ง๋ค noisy image z_t๋ฅผ ๋ฐ์ clean image์ ํจ์น(x_pred)๋ฅผ ์ง์ ์์ธกํ๋ค.
x_pred = net(z_t, t)
๋ฌผ๋ก ์ด๋ฌํ ๋ฐฉ๋ฒ์ผ๋ก ํ ๋ฒ์ ์ด๋ฏธ์ง๊ฐ ๋ณต์๋๋ ๊ฒ์ ์๋๊ณ x_pred๋ ์ต์ข ๊ฒฐ๊ณผ๋ฌผ์ด ์๋๋ผ flow ๊ณ์ฐ์ ์ํ ์ค๊ฐ ์ถ์ ์น๋ค.
3.5 v-loss ๊ธฐ๋ฐ Flow Matching — x_pred → v_pred

Transformer๊ฐ ์์ธกํ x_pred๋ก๋ถํฐ velocity๋ฅผ ๊ณ์ฐํ๋ค.
v_pred = (x_pred - z_t) / (1 - t)
๊ทธ๋ฆฌ๊ณ ์ ๋ต v์ ๋น๊ตํด v-loss๋ก ํ์ตํ๋ค.
์ด ๊ณผ์ ์ด ์ค์ํ ์ด์ ๋
- x-prediction์ด pixel-space์์ ์์ ์
- v-loss๊ฐ gradient ๊ท ํ์ ๋ง์ถค
- flow matching ODE sampling๊ณผ ์์ฐ์ค๋ฝ๊ฒ ์ฐ๊ฒฐ๋จ
๋ค์ ๋งํด, “์์ธก์ x๋ก, ํ์ต์ v๋ก" ํ๋ ๊ฒ์ด๋ค.
3.6 Sampling: multi-step ODE solver (Heun/Euler)
JiT๋ “clean ์ด๋ฏธ์ง๋ฅผ ํ ๋ฒ์ ์์ธกํ๋ ๋ชจ๋ธ”์ด ์๋๋ฉฐ ์ฌ์ ํ multi-step sampling์ ์ํํ๋ค.
Sampling ์ ์ฐจ๋
- ์ด๊ธฐ noise ์ด๋ฏธ์ง zโ ์์ฑ
- patchify → embedding
- Transformer๋ก x_pred(t) ์์ธก
- x_pred(t) → v_pred ๊ณ์ฐ
- z ์ ๋ฐ์ดํธ (Heun / Euler ODE step)
- ๋ค์ patchifyํ์ฌ ๋ฐ๋ณต
- 50 step ์ ๋ ์ํ → ์ต์ข clean image ๋๋ฌ
4. ์คํ ๊ฒฐ๊ณผ

Figure 2. Toy Experiment๋ฅผ ๋ณด๋ฉด 2์ฐจ์(2D) ๋ฐ์ดํฐ ๋ถํฌ๋ฅผ ๋ง๋ค๊ณ , ์ด๋ฅผ ๋ฌด์์ projection matrix๋ก 256D, 1024D, 4096D ๊ฐ์ ๊ณ ์ฐจ์ ๊ณต๊ฐ์ผ๋ก ๋งคํํ ๋ค, ์ด ๊ณ ์ฐจ์ ๋ฐ์ดํฐ๋ฅผ ๋ณด๊ณ 3๊ฐ์ง ๋ฐฉ์(x/ฯต/v)์ผ๋ก ๋ค์ ์๋ณธ ๋ถํฌ๋ฅผ ๋ณต์ํ๋๋ก ํ์ต์์ผฐ๋ค.
๊ฒฐ๊ณผ๋ ๋งค์ฐ ์ง๊ด์ ์ธ๋ฐ,
- x-prediction: D๊ฐ ์๋ฌด๋ฆฌ ์ปค์ ธ๋ ์๋์ 2D manifold๋ฅผ ์ ํํ ๋ณต์
- ฯต-prediction: ๋ฐ์ดํฐ๊ฐ blob ํํ๋ก ๋ถ๊ดด
- v-prediction: ๊ณ ์ฐจ์์์ ๊ฑฐ์ collapse
์ฆ, ๋ณธ์ง์ ์ผ๋ก low-dimensionalํ ๊ตฌ์กฐ(x)๋ฅผ ์์ธกํ๋ ๊ฒ์ ์ฌ์ด ๋ฐ๋ฉด, ๊ณ ์ฐจ์ ์ ์ฒด์ ํผ์ง ๋ ธ์ด์ฆ(ฯต/v)๋ฅผ ์ง์ ์์ธกํ๋ ๊ฒ์ ๋งค์ฐ ์ด๋ ต๋ค๋ ์ฌ์ค์ ๋ณด์ฌ์ค๋ค. ์ด ์คํ์ “pixel diffusion์ด ์คํจํ ์ด์ ๊ฐ compute ๋๋ฌธ์ด ์๋๋ผ ์์ธก target ์๋ชป ๋๋ฌธ”์ด๋ผ๋ JiT ๋ ผ๋ฌธ์ ํต์ฌ ์ฃผ์ฅ์ ์ง๊ด์ ์ผ๋ก ์ฆ๋ช ํ๋ค.


์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด FID ๊ธฐ์ค์ผ๋ก Latent Diffusion(DiT-XL/2) ๋๋น FLOPs๊ฐ ํจ์ฌ ๋ฎ์๋ฐ ์ฑ๋ฅ์ ๋๊ธ์ธ ๊ฒ์ ํ์ธํฆ ์ ์๋ค. VAE ์์ด pixel-space์์ ์ด ์์ค์ ์ฑ๋ฅ์ด ๋์จ ๊ฒ ์์ฒด๊ฐ ํต์ฌ์ ์ธ ๋ถ๋ถ์ด๋ฉฐ, ํนํ 1024×1024 pixel diffusion์ด collapse ์์ด ๋์๊ฐ ์ฒซ ์ฌ๋ก ์ค ํ๋๋ผ๊ณ ๋ณผ ์ ์๋ค.

Figure 7์ ๋ณด๋ฉด x-prediction์ด v-prediction๋ณด๋ค loss๊ฐ ๋ ๋ฎ๊ณ ์์ ์ ์ธ ๊ฒ์ ํ์ธํ ์ ์๊ณ , ์ค์ ๋ณต์ ๊ฒฐ๊ณผ๋ x-prediction์ด t๊ฐ ๋ฎ์ ๋ ์กฐ๊ธ ๋ ํ์ง์ด ์ข๋ค.

์ด ๋ ผ๋ฌธ์ ๊ทธ๋์ pixel diffusion์ด ์ด๋ ค์ ๋ ์ด์ ๋ patch dimension์ด๋ compute ๋ฌธ์ ๋๋ฌธ์ด ์๋๋ผ, ฯต/v ๊ฐ์ off-manifold, high-dimensional target์ ์์ธกํ๊ธฐ ๋๋ฌธ์ด๋ผ๋ ์ ์ ๋ช ํํ๊ฒ ๋ณด์ฌ์ค๋ค. ์ฆ, pixel diffusion์ ํ๊ณ๋ ๊ตฌ์กฐ์ ํ๊ณ๊ฐ ์๋๋ผ ์ค๊ณ์ ์ ํ์ ๋ฌธ์ ์๋ค๋ ์ธ์ฌ์ดํธ๋ฅผ ์ ์ํ๋ค.
๋ฌผ๋ก ์์ฌ์ด ์ ์ ImageNet ์ค์ผ์ผ์์๋ง ์คํ์ด ์งํ๋์ด T2I๋ multi-modal conditioning์ ๊ดํ ์คํ์ด ์๋ค๋ ์ ๊ณผ GPU ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ LDM ๊ณ์ด๋ณด๋ค ๋จ์ด์ง ๊ฐ๋ฅ์ฑ์ด ๋๋ค๋ ์ ๋ ์๋ค. ํ์ง๋ง, ์ด ๋ฐฉํฅ์ด ์ณ๋ค๋ฉด ์์ผ๋ก ๊ด๋ จ๋ ์ฐ๊ตฌ๊ฐ ๋ ์งํ๋์ง ์์๊น ์ถ๋ค.