
Diffusion Model

Langevin dynamics can be used to sample from a distribution $p \propto \exp(f)$ when $f$ is known. In most cases, however, the distribution is unknown and has to be learned from finite samples, e.g., images. This lecture introduces the diffusion generative model, which is built on the diffusion process and is a powerful tool for learning unknown distributions.

From the previous lecture, we know that the diffusion process

$$X_t = X_0 + \sqrt{2t}\,\varepsilon_t, \quad \varepsilon_t \sim N(0, I)$$

will converge to a Gaussian distribution. If you keep adding noise to an image, it gradually degrades into a white-noise image, as follows:

noise

The key idea of the diffusion generative model is to learn the reverse of the diffusion process. For example, we can generate images of cats by gradually removing the noise from a white-noise image.

diffusion

From Noise to Data

Michelangelo:

Every block of stone has a statue inside it. I just chip away the stone that doesn’t look like David.

Let’s say we want to build a generative model as a process that transforms a random noise vector z into a data sample x. Imagine you are Michelangelo, and you want to build a David statue from a block of stone.

David

Prompt: "a hyper realistic twitter post by Michelangelo . include a selfie of him chopping the stones on a halfway done David statue."

We can imagine this process as “sculpting,” where the random noise z is raw materials like stones, and the data sample x is the sculpture. So, a generative model is like Michelangelo who removes the stone to reveal the sculpture.

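$$\begin{array}{ccc} \text{Random noise } z & \xrightarrow{\ \text{reverse}\ } & \text{Sample data } x \\ \text{Rock} & \xrightarrow{\ \text{chip}\ } & \text{Sculpture} \end{array}$$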

This process is hard, which is why there’s so much research on generative models. But as the saying goes, “destruction is easier than construction.” Maybe you can’t build a skyscraper, but you can definitely tear one down. So let’s think about the reverse process of dismantling a skyscraper into bricks and cement.

Let $x_0$ be the finished sculpture (data sample), and $x_T$ be the pile of stones (random noise). Assume it takes $T$ steps to dismantle it. The entire process is:

$$x = x_0 \to x_1 \to x_2 \to \cdots \to x_{T-1} \to x_T = z$$

forward

The challenge of building a David statue is that going from raw materials $x_T$ to the final structure $x_0$ is too big a leap. But if we have the intermediate states $x_1, x_2, \ldots, x_T$, we can understand how to go from one step to the next. Even a master like Michelangelo needs to chip one block at a time.

So, if we know the transformation $x_{t-1} \to x_t$ (dismantling), then reversing it, $x_t \to x_{t-1}$, is like chipping. If we can learn the reverse function $\mu(x_t)$, then starting from $x_T$, we can repeatedly apply $\mu$ to reconstruct the David statue:

$$x_{T-1} = \mu(x_T), \quad x_{T-2} = \mu(x_{T-1}), \quad \ldots$$

Denoising Diffusion Probabilistic Models

We are now ready to introduce Denoising Diffusion Probabilistic Models (DDPM), which implement the reverse process.

As the saying goes, “one bite at a time.” As another saying goes, "Rome wasn't built in a day." DDPM follows this principle by defining a gradual transformation from data samples to noise (dismantling), then learns the reverse (chipping). So it’s more accurate to call DDPM a “gradual model” rather than a “diffusion model.”

Forward Process

Specifically, DDPM defines the dismantling process as:

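$$x_t = \alpha_t x_{t-1} + \beta_t \varepsilon_t, \quad \varepsilon_t \sim N(0, I)$$

Here, $\alpha_t, \beta_t > 0$ and $\alpha_t^2 + \beta_t^2 = 1$. Typically, $\beta_t$ is close to 0, representing a small degradation at each step. The noise $\varepsilon_t$ adds randomness; think of it as raw material injected at each step.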

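Repeating this step, we get:

$$\begin{aligned}
x_t &= \alpha_t x_{t-1} + \beta_t \varepsilon_t \\
&= \alpha_t(\alpha_{t-1} x_{t-2} + \beta_{t-1}\varepsilon_{t-1}) + \beta_t \varepsilon_t \\
&= (\alpha_t \alpha_{t-1}) x_{t-2} + \alpha_t \beta_{t-1} \varepsilon_{t-1} + \beta_t \varepsilon_t \\
&= \cdots \\
&= (\alpha_t \cdots \alpha_1) x_0 + \underbrace{(\alpha_t \cdots \alpha_2)\beta_1 \varepsilon_1 + (\alpha_t \cdots \alpha_3)\beta_2 \varepsilon_2 + \cdots + \alpha_t \beta_{t-1} \varepsilon_{t-1} + \beta_t \varepsilon_t}_{\text{a weighted sum of Gaussian noises}}
\end{aligned}$$

Why do we require $\alpha_t^2 + \beta_t^2 = 1$? Because the weighted sum of independent Gaussian noises is itself a single Gaussian with mean 0, and this constraint makes its variance exactly $1 - (\alpha_t \cdots \alpha_1)^2$, so the squared coefficients of $x_0$ and of the noise add up to 1, i.e.,

$$x_t = \underbrace{(\alpha_t \cdots \alpha_1)}_{\bar\alpha_t} x_0 + \underbrace{\sqrt{1 - (\alpha_t \cdots \alpha_1)^2}}_{\bar\beta_t}\, \bar\varepsilon_t, \quad \bar\varepsilon_t \sim N(0, I) \tag{1}$$

This makes computing $x_t$ very convenient: it can be sampled directly from $x_0$ without simulating the intermediate steps. Furthermore, the schedule is chosen so that $\bar\alpha_T \approx 0$, meaning that after $T$ steps, only noise remains.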

forward
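As a quick numerical sanity check of (1), the sketch below iterates the one-step variance recursion and compares the result with the closed form $\bar\beta_t^2 = 1 - \bar\alpha_t^2$. The constant schedule `beta = 0.1`, the value `T = 1000`, and the helper names are illustrative assumptions, not values taken from the lecture.

```python
import math

def closed_form_coeffs(alphas):
    """Return (alpha_bar_t, beta_bar_t) for t = 1..T as defined in equation (1)."""
    alpha_bars, beta_bars, prod = [], [], 1.0
    for a in alphas:
        prod *= a                                        # alpha_bar_t = alpha_t * ... * alpha_1
        alpha_bars.append(prod)
        beta_bars.append(math.sqrt(1.0 - prod * prod))   # beta_bar_t = sqrt(1 - alpha_bar_t^2)
    return alpha_bars, beta_bars

def noise_variance_by_recursion(alphas, betas):
    """Variance of the accumulated noise in x_t, computed one step at a time.

    A forward step x_t = alpha_t x_{t-1} + beta_t eps_t maps the noise
    variance v_{t-1} to v_t = alpha_t^2 v_{t-1} + beta_t^2, with v_0 = 0.
    """
    v, out = 0.0, []
    for a, b in zip(alphas, betas):
        v = a * a * v + b * b
        out.append(v)
    return out

# Hypothetical constant schedule, chosen only for illustration.
T = 1000
betas = [0.1] * T
alphas = [math.sqrt(1.0 - b * b) for b in betas]         # enforces alpha_t^2 + beta_t^2 = 1

alpha_bars, beta_bars = closed_form_coeffs(alphas)
variances = noise_variance_by_recursion(alphas, betas)
max_gap = max(abs(v - bb * bb) for v, bb in zip(variances, beta_bars))
print(f"largest gap between recursion and closed form: {max_gap:.2e}")
print(f"alpha_bar_T = {alpha_bars[-1]:.4f}")
```

The two computations should agree up to floating-point error, and `alpha_bars[-1]` should be close to 0, which is the "only noise remains" statement for this particular schedule.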

Reverse Process

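Now that we have data pairs $(x_{t-1}, x_t)$ from dismantling, we can learn the reverse step $x_t \to x_{t-1}$ via a model $\mu(x_t)$. The loss is:

$$\| x_{t-1} - \mu(x_t) \|^2 \tag{2}$$

Rearranging the dismantling equation gives:

$$x_{t-1} = \frac{1}{\alpha_t}\left(x_t - \beta_t \varepsilon_t\right)$$

which motivates us to design the reverse model as:

$$\mu(x_t) = \frac{1}{\alpha_t}\left(x_t - \beta_t \varepsilon_\theta(x_t, t)\right)$$

We aim to model the stone-to-be-chipped $\varepsilon_t$ as $\varepsilon_\theta(x_t, t)$.

Up to the weight $\beta_t^2 / \alpha_t^2$, which does not depend on $\theta$, the loss in (2) becomes:

$$\| \varepsilon_t - \varepsilon_\theta(x_t, t) \|^2$$

Now we get the training loss:

$$\mathbb{E}_{x_0 \sim \text{Data},\ t \sim \text{Uniform}\{1, \ldots, T\},\ \varepsilon_t \sim N(0, I)} \| \varepsilon_t - \varepsilon_\theta(x_t, t) \|^2 \tag{3}$$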

Variance-Reduction Trick

From the loss in (3), we can see that the variance is high because there are many random variables to sample:

  • $x_0$ from the data distribution
  • $t$ from $\{1, 2, \ldots, T\}$
  • $\varepsilon_t$ for $t = 1, 2, \ldots, T$ from the normal distribution

The more variables we need to sample, the higher the variance of the loss, and the harder the model is to train.

To reduce the variance, we can use the variance-reduction trick proposed in the DDPM paper. You will see that it only needs a basic change of variables and the rule for adding Gaussians.

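Using the earlier expression for $x_t$ in (1):

$$\begin{aligned}
x_t &= \alpha_t x_{t-1} + \beta_t \varepsilon_t \\
&= \alpha_t\left(\bar\alpha_{t-1} x_0 + \bar\beta_{t-1} \bar\varepsilon_{t-1}\right) + \beta_t \varepsilon_t \\
&= \bar\alpha_t x_0 + \alpha_t \bar\beta_{t-1} \bar\varepsilon_{t-1} + \beta_t \varepsilon_t
\end{aligned}$$

We get the training loss:

$$\| \varepsilon_t - \varepsilon_\theta(\bar\alpha_t x_0 + \alpha_t \bar\beta_{t-1} \bar\varepsilon_{t-1} + \beta_t \varepsilon_t,\ t) \|^2$$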

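We construct two new Gaussian noises $\varepsilon$ and $\omega$:

$$\bar\beta_t \varepsilon = \alpha_t \bar\beta_{t-1} \bar\varepsilon_{t-1} + \beta_t \varepsilon_t, \qquad \bar\beta_t \omega = \beta_t \bar\varepsilon_{t-1} - \alpha_t \bar\beta_{t-1} \varepsilon_t$$

It is easy to check that $\omega, \varepsilon \sim N(0, I)$ and $\mathbb{E}[\varepsilon \omega^T] = 0$. Since they are jointly Gaussian and uncorrelated, $\omega$ is independent of $\varepsilon$.

We can also represent $\varepsilon_t$ by $\omega$ and $\varepsilon$ as:

$$\varepsilon_t = \frac{\beta_t \varepsilon - \alpha_t \bar\beta_{t-1} \omega}{\bar\beta_t}$$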
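The claims about $\varepsilon$ and $\omega$ can be verified with a few lines of arithmetic. The sketch below uses arbitrarily chosen (hypothetical) values for $\alpha_t$ and $\bar\alpha_{t-1}$, checks that both new noises have unit variance and zero covariance, and confirms that $\varepsilon_t$ is recovered from $\varepsilon$ and $\omega$ for one random draw.

```python
import math
import random

# Hypothetical values for a single time step, used only for illustration.
alpha_t = 0.95
beta_t = math.sqrt(1.0 - alpha_t ** 2)                 # alpha_t^2 + beta_t^2 = 1
alpha_bar_prev = 0.6                                   # alpha_bar_{t-1}
beta_bar_prev = math.sqrt(1.0 - alpha_bar_prev ** 2)   # beta_bar_{t-1}
alpha_bar_t = alpha_t * alpha_bar_prev
beta_bar_t = math.sqrt(1.0 - alpha_bar_t ** 2)

# Coefficients of (eps_bar_{t-1}, eps_t) in the definitions of eps and omega.
eps_coef = (alpha_t * beta_bar_prev / beta_bar_t, beta_t / beta_bar_t)
omega_coef = (beta_t / beta_bar_t, -alpha_t * beta_bar_prev / beta_bar_t)

# For independent standard normals, variances and covariance follow from the coefficients.
var_eps = eps_coef[0] ** 2 + eps_coef[1] ** 2          # expected: 1
var_omega = omega_coef[0] ** 2 + omega_coef[1] ** 2    # expected: 1
cov = eps_coef[0] * omega_coef[0] + eps_coef[1] * omega_coef[1]  # expected: 0

# Draw one pair of independent noises and invert the change of variables.
rng = random.Random(0)
eps_bar_prev, eps_t = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
eps = eps_coef[0] * eps_bar_prev + eps_coef[1] * eps_t
omega = omega_coef[0] * eps_bar_prev + omega_coef[1] * eps_t
eps_t_recovered = (beta_t * eps - alpha_t * beta_bar_prev * omega) / beta_bar_t
```

The identity behind all three checks is $\alpha_t^2 \bar\beta_{t-1}^2 + \beta_t^2 = \bar\beta_t^2$, which follows from $\alpha_t^2 + \beta_t^2 = 1$ and $\bar\alpha_t = \alpha_t \bar\alpha_{t-1}$.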

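Substitute back into the loss in (3), rescaling the single-step noise predictor as $\frac{\beta_t}{\bar\beta_t}\varepsilon_\theta$ so that the network $\varepsilon_\theta$ now targets the accumulated noise $\varepsilon$, and simplify:

$$\begin{aligned}
&\mathbb{E}_{\varepsilon, \omega \sim N(0, I)} \left\| \frac{\beta_t \varepsilon - \alpha_t \bar\beta_{t-1} \omega}{\bar\beta_t} - \frac{\beta_t}{\bar\beta_t} \varepsilon_\theta(\bar\alpha_t x_0 + \bar\beta_t \varepsilon,\ t) \right\|^2 \\
&= \frac{\beta_t^2}{\bar\beta_t^2}\, \mathbb{E}_{\varepsilon \sim N(0, I)} \left\| \varepsilon - \varepsilon_\theta(\bar\alpha_t x_0 + \bar\beta_t \varepsilon,\ t) \right\|^2 + \text{const},
\end{aligned} \tag{4}$$

where the last equality holds by taking the expectation over $\omega$: the loss in (4) is quadratic in $\omega$, the cross term vanishes because $\omega$ has zero mean and is independent of $\varepsilon$, and the remaining term $\frac{\alpha_t^2 \bar\beta_{t-1}^2}{\bar\beta_t^2} \mathbb{E}\|\omega\|^2$ is a constant.

Dropping the weight and the constant gives the final DDPM loss:

$$\mathbb{E}_{x_0 \sim \text{Data},\ t \sim \text{Uniform}\{1, \ldots, T\},\ \varepsilon \sim N(0, I)} \left\| \varepsilon - \varepsilon_\theta(\bar\alpha_t x_0 + \bar\beta_t \varepsilon,\ t) \right\|^2$$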

We can see that the noise predictor $\varepsilon_\theta$ must take the time step $t$ as input. In the original DDPM paper, $t$ is fed to the network by adding the Transformer sinusoidal position embedding into each residual block. The paper also suggests choosing $T = 1000$ and $\alpha_t = \sqrt{1 - 0.02t/T}$, which takes smaller steps closer to the original data distribution.
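To make the schedule concrete, here is a minimal sketch that tabulates $\alpha_t = \sqrt{1 - 0.02t/T}$ and the matching $\beta_t = \sqrt{1 - \alpha_t^2}$ for $T = 1000$. The function name `ddpm_schedule` and its parameter `max_var` are illustrative assumptions.

```python
import math

def ddpm_schedule(T=1000, max_var=0.02):
    """Return lists of alpha_t = sqrt(1 - max_var * t / T) and beta_t for t = 1..T."""
    alphas = [math.sqrt(1.0 - max_var * t / T) for t in range(1, T + 1)]
    betas = [math.sqrt(1.0 - a * a) for a in alphas]     # alpha_t^2 + beta_t^2 = 1
    return alphas, betas

alphas, betas = ddpm_schedule()
alpha_bar_T = math.prod(alphas)                          # alpha_bar_T = alpha_T * ... * alpha_1

print(f"beta_1 = {betas[0]:.4f}, beta_T = {betas[-1]:.4f}")
print(f"alpha_bar_T = {alpha_bar_T:.4f}")
```

The noise scale $\beta_t$ grows with $t$, so the steps nearest the data ($t$ small) are the smallest, while the product $\bar\alpha_T$ ends up close to 0.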

Sampling

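Once $\varepsilon_\theta$ is trained, DDPM generates samples by starting from $x_T \sim N(0, I)$ and running the reverse step. Because the final loss trains $\varepsilon_\theta$ to predict the accumulated noise $\varepsilon$, and $\mathbb{E}[\varepsilon_t \mid \varepsilon] = \frac{\beta_t}{\bar\beta_t}\varepsilon$, the estimate of the single-step noise is $\frac{\beta_t}{\bar\beta_t}\varepsilon_\theta(x_t, t)$ and the update reads:

$$x_{t-1} = \frac{1}{\alpha_t}\left(x_t - \frac{\beta_t^2}{\bar\beta_t} \varepsilon_\theta(x_t, t)\right)$$

For random sampling, add fresh noise at each step:

$$x_{t-1} = \frac{1}{\alpha_t}\left(x_t - \frac{\beta_t^2}{\bar\beta_t} \varepsilon_\theta(x_t, t)\right) + \beta_t z, \quad z \sim N(0, I).$$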

ddpm_sample

ddpm_sample

Code Implementation

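We can implement the DDPM model in PyTorch by defining a class DDPM that contains both the noise predictor and the sampling step. Note that the code follows the DDPM paper's notation: `betas` holds the noise variances ($\beta_t^2$ in the notation above), `alphas` holds $\alpha_t^2$, and `alpha_bars` holds $\bar\alpha_t^2$.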

import torch 
from torch import nn, Tensor

class DDPM(nn.Module):
    def __init__(self, dim: int = 2, h: int = 64, n_steps: int = 100):
        super().__init__()
        self.n_steps = n_steps
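        # Noise-variance schedule from small to large values (DDPM-paper β_t).
        # Registered as buffers so they follow the module across devices.
        self.register_buffer("betas", torch.linspace(1e-4, 0.02, n_steps))
        # Calculate alphas: α_t = 1 - β_t
        self.register_buffer("alphas", 1.0 - self.betas)
        # Calculate cumulative product of alphas: ᾱ_t = ∏_{i=1}^t α_i
        self.register_buffer("alpha_bars", torch.cumprod(self.alphas, dim=0))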

        # Simple MLP network for noise prediction ε_θ
        self.net = nn.Sequential(
            nn.Linear(dim + 1, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim)
        )

    def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
        # Reshape time step and concatenate with noisy input
        # This implements ε_θ(x_t, t)
        t = t.view(-1, 1)
        return self.net(torch.cat((t, x_t), dim=-1))

    @torch.no_grad()
    def sample_step(self, x_t: Tensor, t: int) -> Tensor:
        # Sample Gaussian noise for the stochastic part of sampling
        noise = torch.randn_like(x_t)
        # Get α_t and ᾱ_t for the current timestep index (0-based)
        alpha_t = self.alphas[t]
        alpha_bar_t = self.alpha_bars[t]
        # Coefficient of the noise prediction: (1 - α_t) / sqrt(1 - ᾱ_t)
        coeff = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
        # Normalize time step to [0, 1) for the model, as done in training
        t_tensor = torch.full((x_t.shape[0],), t / self.n_steps, device=x_t.device)
        # Predict noise using the model: ε_θ(x_t, t)
        predicted_noise = self(t_tensor, x_t)
        # Deterministic part of the update: (x_t - coeff * ε_θ(x_t, t)) / sqrt(α_t)
        x_prev = (x_t - coeff * predicted_noise) / torch.sqrt(alpha_t)
        # Add the noise term on every step except the final one (t == 0)
        if t > 0:
            x_prev = x_prev + torch.sqrt(1 - alpha_t) * noise
        return x_prev

We can train the DDPM model on the two-moons dataset by computing the final DDPM loss, i.e., the variance-reduced form of (3).

from sklearn.datasets import make_moons

ddpm = DDPM()
optimizer = torch.optim.Adam(ddpm.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

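for _ in range(10000):
    # Draw a fresh batch of clean two-moons samples x_0 (as float32)
    x_0 = torch.tensor(make_moons(256, noise=0.05)[0], dtype=torch.float32)
    # Pick a random time-step index for every sample in the batch
    t = torch.randint(0, ddpm.n_steps, (x_0.shape[0],))
    noise = torch.randn_like(x_0)

    # Jump straight to x_t with the closed form: sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε
    alpha_bar_t = ddpm.alpha_bars[t].view(-1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise

    optimizer.zero_grad()
    # Normalize the time step to [0, 1), matching sample_step
    t_normalized = t / ddpm.n_steps
    predicted_noise = ddpm(t_normalized, x_t)
    # Regress the predicted noise onto the noise that was actually added
    loss = loss_fn(predicted_noise, noise)
    loss.backward()
    optimizer.step()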
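The training loop stops before generating anything. As a dependency-free sketch of the full sampling loop, the block below repeats the `sample_step` update from the last time step down to the first for a single scalar sample, using the same variance notation as the class (`betas`, `alphas`, `alpha_bars`). Instead of a trained network it uses an oracle noise predictor for toy data in which every data point equals `m`; that oracle, its extra `alpha_bar_t` argument, and the helper names are assumptions made purely so the loop can be checked exactly.

```python
import math
import random

def linear_betas(n_steps=100, start=1e-4, end=0.02):
    """Evenly spaced noise variances, analogous to torch.linspace(start, end, n_steps)."""
    step = (end - start) / (n_steps - 1)
    return [start + i * step for i in range(n_steps)]

def ancestral_sample(eps_model, betas, rng):
    """Run the reverse chain from t = n_steps - 1 down to 0 for one scalar sample."""
    alphas = [1.0 - b for b in betas]
    alpha_bars, prod = [], 1.0
    for a in alphas:
        prod *= a                                   # cumulative product of alphas
        alpha_bars.append(prod)

    x = rng.gauss(0.0, 1.0)                         # start from pure noise, x_T ~ N(0, 1)
    for t in reversed(range(len(betas))):
        coeff = (1.0 - alphas[t]) / math.sqrt(1.0 - alpha_bars[t])
        x = (x - coeff * eps_model(x, t, alpha_bars[t])) / math.sqrt(alphas[t])
        if t > 0:                                   # no fresh noise on the final step
            x += math.sqrt(1.0 - alphas[t]) * rng.gauss(0.0, 1.0)
    return x

def make_oracle(m):
    """Exact noise predictor when every data point equals m (a toy assumption)."""
    def eps_model(x_t, t, alpha_bar_t):
        # x_t = sqrt(alpha_bar_t) * m + sqrt(1 - alpha_bar_t) * eps, solved for eps
        return (x_t - math.sqrt(alpha_bar_t) * m) / math.sqrt(1.0 - alpha_bar_t)
    return eps_model

sample = ancestral_sample(make_oracle(2.0), linear_betas(), random.Random(0))
print(sample)
```

With the exact predictor, the final noiseless step maps any $x_t$ back to `m`, so the chain should land on the single data point. With the PyTorch class, the analogous loop would call `ddpm.sample_step(x, t)` for `t` from `ddpm.n_steps - 1` down to `0`, starting from `torch.randn(n, 2)`.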

Contextual DDPM

We can generalize the DDPM model to sample from a conditional distribution $p(x \mid c)$, where the context $c$ could even be a text prompt.

We simply add the context $c$ to the input of the noise predictor, $\varepsilon_\theta(x_t, t, c)$, and train with the loss:

$$\mathbb{E}_{(x_0, c) \sim \text{Data},\ t \sim \text{Uniform}\{1, \ldots, T\},\ \varepsilon \sim N(0, I)} \left\| \varepsilon - \varepsilon_\theta(\bar\alpha_t x_0 + \bar\beta_t \varepsilon,\ t,\ c) \right\|^2$$

Additional techniques such as CLIP can be used to improve the quality of the generated images. Please refer to the OpenAI DALL-E paper for more details.

dalle

Classifier-Free Guidance

For conditional generation, there is a trade-off between the fidelity and the mode coverage (diversity) of the generated images. To tune this trade-off, we can use classifier-free guidance, which samples with a linear combination of the conditional and unconditional noise predictions:

$$\tilde\varepsilon_\theta(x_t, t, c) = (1 + w)\,\varepsilon_\theta(x_t, t, c) - w\,\varepsilon_\theta(x_t, t),$$

where $\varepsilon_\theta(x_t, t)$ is an unconditional noise predictor. Usually, we use the same network for both the conditional and unconditional cases. For the unconditional case, we pass a null token $\varnothing$ as the context $c$ and fit $\varepsilon_\theta(x_t, t) = \varepsilon_\theta(x_t, t, \varnothing)$.
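The guidance formula is a plain elementwise extrapolation, as the following sketch shows. The function name `guided_noise` and the two made-up prediction vectors are illustrative assumptions; in practice both inputs would come from the same network, evaluated once with the context and once with the null token.

```python
def guided_noise(eps_cond, eps_uncond, w):
    """Classifier-free guidance: (1 + w) * eps_cond - w * eps_uncond, elementwise."""
    return [(1.0 + w) * c - w * u for c, u in zip(eps_cond, eps_uncond)]

eps_cond = [0.5, -1.0]      # made-up output of eps_theta(x_t, t, c)
eps_uncond = [0.1, -0.2]    # made-up output of eps_theta(x_t, t, null token)

no_guidance = guided_noise(eps_cond, eps_uncond, w=0.0)   # reduces to eps_cond
pushed = guided_noise(eps_cond, eps_uncond, w=2.0)        # extrapolates past eps_cond
```

Setting `w = 0` recovers ordinary conditional sampling, while larger `w` moves the prediction further along the direction `eps_cond - eps_uncond`.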

The training process can be summarized as follows.

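Input: $p_{\text{uncond}}$: probability of unconditional training

Repeat:

  1. Sample data with conditioning from the dataset: $(x, c) \sim p(x, c)$
  2. Randomly discard conditioning to train unconditionally: $c \leftarrow \varnothing$ with probability $p_{\text{uncond}}$
  3. Sample log SNR value: $\lambda \sim p(\lambda)$
  4. Sample Gaussian noise: $\epsilon \sim N(0, I)$
  5. Corrupt data to the sampled log SNR value: $z_\lambda = \alpha_\lambda x + \sigma_\lambda \epsilon$
  6. Take gradient step on $\nabla_\theta \| \epsilon_\theta(z_\lambda, c) - \epsilon \|^2$

Until converged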
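Step 2 is the only part of this procedure that differs from ordinary conditional training. A minimal sketch of it is shown below, where `NULL_TOKEN`, the helper name, and the value `p_uncond = 0.1` are hypothetical choices made for illustration.

```python
import random

NULL_TOKEN = None  # hypothetical stand-in for the null context

def maybe_drop_condition(c, p_uncond, rng):
    """Return the null token with probability p_uncond, otherwise return c unchanged."""
    return NULL_TOKEN if rng.random() < p_uncond else c

rng = random.Random(0)
labels = ["cat"] * 10_000
kept = [maybe_drop_condition(c, p_uncond=0.1, rng=rng) for c in labels]
drop_rate = sum(c is NULL_TOKEN for c in kept) / len(kept)
print(f"fraction trained unconditionally: {drop_rate:.3f}")
```

Because the same network sees both real contexts and the null token, it learns the conditional and unconditional predictors at once, which is what the guidance formula needs at sampling time.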

We will then use $\tilde\varepsilon_\theta(x_t, t, c)$ to sample from the model.

As $w$ increases from 0 towards $\infty$, the generated images gain fidelity (they match the condition $c$ more closely) but lose diversity.