Diffusion Model¶
Recall that Langevin dynamics can be used to sample from a distribution $p(x)$ using only its score function $\nabla_x \log p(x)$.
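As a quick recap of that sampler, one standard form of the Langevin update (with a step size $\eta$, our notation here) is:

$$x_{k+1} = x_k + \frac{\eta}{2}\,\nabla_x \log p(x_k) + \sqrt{\eta}\,\xi_k, \qquad \xi_k \sim \mathcal{N}(0, I).$$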
From the previous lecture, we know that the diffusion process will converge to a Gaussian distribution. Imagine adding noise to an image of a cat: it will gradually be blurred into pure white noise.
The key idea of a diffusion generative model is to learn the reverse of this diffusion process: starting from noise, we can then generate images of cats by removing the noise step by step.
From Noise to Data¶
Michelangelo:
Every block of stone has a statue inside it. I just chip away the stone that doesn’t look like David.
Let's say we want to build a generative model as a process that transforms a random noise vector $z \sim \mathcal{N}(0, I)$ into a data sample $x$.
Prompt: "a hyper realistic twitter post by Michelangelo . include a selfie of him chopping the stones on a halfway done David statue."
We can imagine this process as "sculpting," where the random noise $z$ is the block of stone and the model gradually chips away everything that does not look like the data $x$.
This process is hard, which is why there’s so much research on generative models. But as the saying goes, “destruction is easier than construction.” Maybe you can’t build a skyscraper, but you can definitely tear one down. So let’s think about the reverse process of dismantling a skyscraper into bricks and cement.
Let $x_0$ denote a data sample (the statue, or the finished skyscraper) and $x_T$ denote pure Gaussian noise (the raw materials).

The challenge of building a David statue is that going from the raw materials $x_T$ to the finished work $x_0$ is hard, while the reverse direction, reducing $x_0$ back to raw materials, is easy.

So, if we know the transformation $x_0 \to x_T$ step by step, we can try to learn each small step in reverse and then run the whole reverse chain $x_T \to x_0$ to generate data from noise.
Denoising Diffusion Probabilistic Models¶
We are now ready to introduce Denoising Diffusion Probabilistic Models (DDPM), a concrete way to implement this reverse process.
As the saying goes, "one bite at a time." As another saying goes, "Rome wasn't built in a day." DDPM follows this principle: it defines a gradual transformation from data samples to noise (dismantling) and then learns the reverse (chipping). So it is arguably more accurate to call DDPM a "gradual model" rather than a "diffusion model."
Forward Process¶
Specifically, DDPM defines the dismantling (forward) process as:

$$x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\,\varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, I), \quad t = 1, \dots, T. \tag{1}$$

Here, $\alpha_t = 1 - \beta_t$, where $\beta_t \in (0, 1)$ is a small, pre-defined noise level that typically increases with $t$, and $\varepsilon_t$ is a fresh Gaussian noise added at step $t$.
Repeating this step starting from a data sample $x_0$, we get:

$$x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\bar\varepsilon_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \quad \bar\varepsilon_t \sim \mathcal{N}(0, I).$$
Why do we require the squared coefficients of each step to sum to one, i.e. $(\sqrt{\alpha_t})^2 + (\sqrt{1-\alpha_t})^2 = 1$? Because a sum of independent Gaussians is again Gaussian, and with this normalization all the accumulated noise collapses into a single standard Gaussian $\bar\varepsilon_t$, with the coefficients of $x_0$ and $\bar\varepsilon_t$ again having squares that sum to one.

This makes computing $x_t$ easy: we can sample it directly from $x_0$ in one shot instead of simulating the whole chain, and as long as $\bar\alpha_T \approx 0$, the endpoint $x_T$ is (approximately) pure Gaussian noise.
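To make the one-shot property concrete, here is a small numerical sketch (not part of the original notes; the schedule values are arbitrary illustrative choices) that compares iterating the per-step update with the closed-form sampling of $x_T$:

import torch

torch.manual_seed(0)
T = 100
betas = torch.linspace(1e-4, 0.02, T)        # noise schedule β_t (arbitrary choice)
alphas = 1.0 - betas                         # α_t = 1 - β_t
alpha_bars = torch.cumprod(alphas, dim=0)    # ᾱ_t = ∏ α_s

x0 = torch.randn(10000, 2)                   # stand-in "data"

# Iterate the one-step update x_t = sqrt(α_t) x_{t-1} + sqrt(1 - α_t) ε_t
x = x0.clone()
for t in range(T):
    x = torch.sqrt(alphas[t]) * x + torch.sqrt(1 - alphas[t]) * torch.randn_like(x)

# One-shot sampling x_T = sqrt(ᾱ_T) x_0 + sqrt(1 - ᾱ_T) ε
x_direct = torch.sqrt(alpha_bars[-1]) * x0 + torch.sqrt(1 - alpha_bars[-1]) * torch.randn_like(x0)

# Both should have (approximately) matching mean and variance
print(x.mean().item(), x.var().item())
print(x_direct.mean().item(), x_direct.var().item())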
Reverse Process¶
Now that we have data pairs $(x_{t-1}, x_t)$ generated by the forward process, we can learn the reverse (chipping) step by training a model $\mu_\theta(x_t)$ to predict $x_{t-1}$ from $x_t$, for example by minimizing the regression loss

$$\mathbb{E}\big[\|x_{t-1} - \mu_\theta(x_t)\|^2\big]. \tag{2}$$
The dismantling equation, rearranged as

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\big(x_t - \sqrt{1-\alpha_t}\,\varepsilon_t\big),$$

motivates us to design the reverse model as:

$$\mu_\theta(x_t) = \frac{1}{\sqrt{\alpha_t}}\big(x_t - \sqrt{1-\alpha_t}\,\varepsilon_\theta(x_t, t)\big).$$

That is, we aim to model the stone-to-be-chipped, namely the noise $\varepsilon_t$, with a neural network $\varepsilon_\theta(x_t, t)$, rather than modeling $x_{t-1}$ directly.
The loss in (2) becomes:

$$\frac{1-\alpha_t}{\alpha_t}\,\mathbb{E}\big[\|\varepsilon_t - \varepsilon_\theta(x_t, t)\|^2\big].$$
Dropping the $t$-dependent weight, we now get the training loss:

$$\mathbb{E}_{x_0,\, t,\, \bar\varepsilon_{t-1},\, \varepsilon_t}\Big[\big\|\varepsilon_t - \varepsilon_\theta(x_t, t)\big\|^2\Big], \qquad x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\,\varepsilon_t, \quad x_{t-1} = \sqrt{\bar\alpha_{t-1}}\, x_0 + \sqrt{1-\bar\alpha_{t-1}}\,\bar\varepsilon_{t-1}. \tag{3}$$
Variance-Reduction Trick¶
From the loss in (3), we can see that the variance is high because there are many things to sample:

- $x_0$ from the data distribution
- $t$ from $\{1, \dots, T\}$
- $\bar\varepsilon_{t-1}$ and $\varepsilon_t$ from the standard normal distribution
The more quantities we have to sample, the higher the variance of the loss estimate, and the harder the model is to train.
To reduce the variance, we can use the variance-reduction trick proposed in the DDPM paper. You will see that it is just a change of variables combined with the Gaussian-addition trick.
Using the earlier expression for $x_{t-1}$ in terms of $x_0$ and $\bar\varepsilon_{t-1}$, the network input expands to

$$x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{\alpha_t(1-\bar\alpha_{t-1})}\,\bar\varepsilon_{t-1} + \sqrt{1-\alpha_t}\,\varepsilon_t.$$

We get the training loss:

$$\mathbb{E}_{x_0,\, t,\, \bar\varepsilon_{t-1},\, \varepsilon_t}\Big[\big\|\varepsilon_t - \varepsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{\alpha_t(1-\bar\alpha_{t-1})}\,\bar\varepsilon_{t-1} + \sqrt{1-\alpha_t}\,\varepsilon_t,\ t\big)\big\|^2\Big],$$
and it is easy to check that the two noise coefficients satisfy $\alpha_t(1-\bar\alpha_{t-1}) + (1-\alpha_t) = 1 - \bar\alpha_t$, so together the two noise terms have variance $1-\bar\alpha_t$.
We construct two new Gaussian noises:

$$\varepsilon = \frac{\sqrt{\alpha_t(1-\bar\alpha_{t-1})}\,\bar\varepsilon_{t-1} + \sqrt{1-\alpha_t}\,\varepsilon_t}{\sqrt{1-\bar\alpha_t}}, \qquad \omega = \frac{\sqrt{1-\alpha_t}\,\bar\varepsilon_{t-1} - \sqrt{\alpha_t(1-\bar\alpha_{t-1})}\,\varepsilon_t}{\sqrt{1-\bar\alpha_t}},$$

which are independent standard Gaussians, since this is just a rotation of the pair $(\bar\varepsilon_{t-1}, \varepsilon_t)$. We can also represent $\varepsilon_t$ in terms of the new noises:

$$\varepsilon_t = \frac{\sqrt{1-\alpha_t}\,\varepsilon - \sqrt{\alpha_t(1-\bar\alpha_{t-1})}\,\omega}{\sqrt{1-\bar\alpha_t}},$$

and the network input simplifies to $\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon$.
Substituting back into the loss in (3) and simplifying, we get

$$\mathbb{E}_{x_0,\, t,\, \varepsilon,\, \omega}\left[\left\|\frac{\sqrt{1-\alpha_t}\,\varepsilon - \sqrt{\alpha_t(1-\bar\alpha_{t-1})}\,\omega}{\sqrt{1-\bar\alpha_t}} - \varepsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon,\ t\big)\right\|^2\right] = \frac{1-\alpha_t}{1-\bar\alpha_t}\,\mathbb{E}_{x_0,\, t,\, \varepsilon}\left[\left\|\varepsilon - \tfrac{\sqrt{1-\bar\alpha_t}}{\sqrt{1-\alpha_t}}\,\varepsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon,\ t\big)\right\|^2\right] + \text{const},$$

where the last equality holds by taking the expectation over $\omega$: since $\omega$ is independent of $\varepsilon$ and has zero mean, the cross term vanishes and the $\|\omega\|^2$ term is a constant that does not depend on $\theta$.
Absorbing the constant factor $\tfrac{\sqrt{1-\bar\alpha_t}}{\sqrt{1-\alpha_t}}$ into the network (equivalently, letting $\varepsilon_\theta$ predict the combined noise $\varepsilon$ directly) and dropping the $t$-dependent weight, this gives us the final DDPM loss:

$$\mathcal{L}_{\mathrm{DDPM}} = \mathbb{E}_{x_0,\, t,\, \varepsilon}\Big[\big\|\varepsilon - \varepsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon,\ t\big)\big\|^2\Big], \qquad \varepsilon \sim \mathcal{N}(0, I).$$

We can see that we need to train the noise predictor $\varepsilon_\theta(x_t, t)$ by sampling only $x_0$ from the data, $t$ from $\{1, \dots, T\}$, and a single Gaussian noise $\varepsilon$: one fewer noise to sample than in (3), and hence a lower-variance loss estimate.
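As a quick sanity check on the Gaussian-addition trick (a sketch with arbitrary illustrative values for $\alpha_t$ and $\bar\alpha_{t-1}$, not part of the original derivation), we can verify numerically that $\varepsilon$ and $\omega$ are uncorrelated standard Gaussians:

import torch

torch.manual_seed(0)
alpha_t, alpha_bar_prev = 0.98, 0.5              # arbitrary illustrative values
a = (alpha_t * (1 - alpha_bar_prev)) ** 0.5      # coefficient of ε̄_{t-1}
b = (1 - alpha_t) ** 0.5                         # coefficient of ε_t
s = (1 - alpha_t * alpha_bar_prev) ** 0.5        # sqrt(1 - ᾱ_t), since ᾱ_t = α_t ᾱ_{t-1}

eps_bar = torch.randn(100_000)
eps_t = torch.randn(100_000)
eps = (a * eps_bar + b * eps_t) / s              # combined noise ε
omega = (b * eps_bar - a * eps_t) / s            # orthogonal noise ω

# Each should have variance ≈ 1 and their correlation should be ≈ 0
print(eps.var().item(), omega.var().item(), (eps * omega).mean().item())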
Sampling¶
Once $\varepsilon_\theta$ is trained, we can generate new samples by starting from pure noise $x_T \sim \mathcal{N}(0, I)$ and running the reverse (chipping) process one step at a time.

You can do the random sampling as follows:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\varepsilon_\theta(x_t, t)\right) + \sigma_t\, z, \qquad z \sim \mathcal{N}(0, I),$$

where the noise prediction is rescaled by $\tfrac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}$ because $\varepsilon_\theta$ was trained to predict the combined noise $\varepsilon$, and $\sigma_t$ controls how much fresh noise is injected at each step (the code below uses $\sigma_t = \sqrt{1-\alpha_t}$ and adds no noise at the final step).
Code Implementation¶
We can implement the DDPM model in PyTorch by defining a class DDPM that contains both the noise predictor and a single reverse sampling step.
import torch
from torch import nn, Tensor


class DDPM(nn.Module):
    def __init__(self, dim: int = 2, h: int = 64, n_steps: int = 100):
        super().__init__()
        self.n_steps = n_steps
        # Define beta schedule from small to large values
        self.betas = torch.linspace(1e-4, 0.02, n_steps)
        # Calculate alphas: α_t = 1 - β_t
        self.alphas = 1.0 - self.betas
        # Calculate cumulative product of alphas: ᾱ_t = ∏_{i=1}^t α_i
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        # Simple MLP network for noise prediction ε_θ
        self.net = nn.Sequential(
            nn.Linear(dim + 1, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim)
        )

    def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
        # Reshape time step and concatenate with noisy input
        # This implements ε_θ(x_t, t)
        t = t.view(-1, 1)
        return self.net(torch.cat((t, x_t), dim=-1))

    def sample_step(self, x_t: Tensor, t: int) -> Tensor:
        # Sample Gaussian noise for the stochastic part of sampling
        noise = torch.randn_like(x_t)
        # Get α_t and ᾱ_t for current timestep
        alpha_t = self.alphas[t]
        alpha_bar_t = self.alpha_bars[t]
        # Calculate coefficient for the noise prediction
        coeff = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
        # Normalize time step to [0,1] range for the model
        t_tensor = torch.full((x_t.shape[0],), t / self.n_steps, device=x_t.device)
        # Predict noise using the model: ε_θ(x_t, t)
        predicted_noise = self(t_tensor, x_t)
        # Implement the sampling formula: x_{t-1} = (x_t - coeff * ε_θ(x_t, t)) / sqrt(α_t) + noise term
        x_t = (x_t - coeff * predicted_noise) / torch.sqrt(alpha_t)
        # Add noise term if not the final step, implementing the stochastic sampling
        return x_t + torch.sqrt(1 - alpha_t) * noise if t > 0 else x_t
We can train the DDPM model by computing the final DDPM loss derived above.
from sklearn.datasets import make_moons

ddpm = DDPM()
optimizer = torch.optim.Adam(ddpm.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for _ in range(10000):
    # Sample a batch of data points x_0 from the two-moons dataset
    x_0 = Tensor(make_moons(256, noise=0.05)[0])
    # Sample a random timestep t for each data point
    t = torch.randint(0, ddpm.n_steps, (x_0.shape[0],))
    # Sample Gaussian noise ε
    noise = torch.randn_like(x_0)
    # Corrupt the data: x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε
    alpha_bar_t = ddpm.alpha_bars[t].view(-1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
    optimizer.zero_grad()
    # Predict the noise and regress it onto the true noise (the DDPM loss)
    t_normalized = t / ddpm.n_steps
    predicted_noise = ddpm(t_normalized, x_t)
    loss = loss_fn(predicted_noise, noise)
    loss.backward()
    optimizer.step()
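Once training has finished, we can generate new two-moons samples by starting from pure Gaussian noise and repeatedly applying sample_step. The following is a minimal sketch of that loop (the batch size of 500 is an arbitrary choice):

# Run the reverse process from x_T ~ N(0, I) down to x_0
with torch.no_grad():
    x = torch.randn(500, 2)
    for t in reversed(range(ddpm.n_steps)):
        x = ddpm.sample_step(x, t)
# x now holds generated samples that should resemble the two-moons data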
Contextual DDPM¶
We can generalize the DDPM model to generate from a conditional distribution $p(x \mid c)$, where $c$ is a context such as a class label or a text prompt.

We can simply add the context $c$ as an extra input to the noise predictor, i.e. train $\varepsilon_\theta(x_t, t, c)$ with the same loss as before.
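As a rough sketch of this idea (the class name ContextualDDPM, the context dimensionality c_dim, and the simple concatenation are our own illustrative choices, not a prescribed architecture), the noise predictor could be extended like this:

class ContextualDDPM(DDPM):
    def __init__(self, dim: int = 2, h: int = 64, n_steps: int = 100, c_dim: int = 1):
        super().__init__(dim=dim, h=h, n_steps=n_steps)
        # Rebuild the MLP so it also sees the context: ε_θ(x_t, t, c)
        self.net = nn.Sequential(
            nn.Linear(dim + 1 + c_dim, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim)
        )

    def forward(self, t: Tensor, x_t: Tensor, c: Tensor) -> Tensor:
        # Concatenate time, noisy input, and context (only the predictor is shown;
        # training and sampling would also need to pass c through)
        t = t.view(-1, 1)
        return self.net(torch.cat((t, x_t, c), dim=-1))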
Further techniques, such as CLIP, can be used to improve the quality of the generated images; please refer to the OpenAI DALL-E paper for more details.
Classifier-Free Guidance¶
For conditional generation, there is a trade-off between the fidelity and the mode coverage (diversity) of the generated images. To tune this trade-off, we can use classifier-free guidance and sample using a linear combination of the conditional and unconditional noise predictions:

$$\tilde\varepsilon_\theta(x_t, c) = (1 + w)\,\varepsilon_\theta(x_t, c) - w\,\varepsilon_\theta(x_t),$$

where $w \ge 0$ is the guidance strength and $\varepsilon_\theta(x_t) = \varepsilon_\theta(x_t, c = \varnothing)$ is the unconditional prediction obtained by feeding a null context.
The training process can be summarized as follows (in the notation of the classifier-free guidance paper, where $\lambda$ is a log signal-to-noise ratio and $z_\lambda = \alpha_\lambda x + \sigma_\lambda \varepsilon$ plays the role of the noisy sample $x_t$).

Input: $p_{\text{uncond}}$, the probability of dropping the conditioning during training

Repeat:
- Sample data with conditioning from the dataset: $(x, c) \sim p(x, c)$
- Randomly discard conditioning to train unconditionally: $c \leftarrow \varnothing$ with probability $p_{\text{uncond}}$
- Sample log SNR value: $\lambda \sim p(\lambda)$
- Sample Gaussian noise: $\varepsilon \sim \mathcal{N}(0, I)$
- Corrupt data to the sampled log SNR value: $z_\lambda = \alpha_\lambda x + \sigma_\lambda \varepsilon$
- Take gradient step on $\nabla_\theta \big\|\varepsilon_\theta(z_\lambda, c) - \varepsilon\big\|^2$

Until converged
We will then use the guided prediction $\tilde\varepsilon_\theta(x_t, c)$ in place of $\varepsilon_\theta(x_t, c)$ during sampling.

When $w = 0$ we recover the purely conditional model; increasing $w$ pushes samples toward higher fidelity to the conditioning at the cost of diversity.
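As an illustration of the guided sampling step (a sketch assuming the conditional predictor above and a zero vector as the null context; the helper name guided_noise is our own), the linear combination can be computed like this:

def guided_noise(model: ContextualDDPM, t: Tensor, x_t: Tensor, c: Tensor, w: float) -> Tensor:
    # Conditional prediction ε_θ(x_t, c)
    eps_cond = model(t, x_t, c)
    # Unconditional prediction ε_θ(x_t), obtained by feeding the null context
    eps_uncond = model(t, x_t, torch.zeros_like(c))
    # Classifier-free guidance: (1 + w) ε_θ(x_t, c) - w ε_θ(x_t)
    return (1 + w) * eps_cond - w * eps_uncond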