# Flow Matching
In diffusion models, the forward process adds noise to the data distribution as
\[
x_t = \alpha_t x_0 + \sqrt{1 - \alpha_t^2}\,\epsilon, \qquad \epsilon \sim N(0, I),
\]
where \(\alpha_t\) is a function of \(t\).
This process connects the data distribution \(p_0(x)\) to a Gaussian distribution \(p_{\infty}(x)\) as \(t \to \infty\). DDPM then reverses this process: starting from the Gaussian distribution \(p_{\infty}(x)\), it recovers the data distribution \(p_0(x)\) by training a denoising diffusion model.
However, DDPM has two problems:
- Noise is added at every step of the forward process, which introduces variance into the process.
- \(p_T(x)\) is close to, but not exactly, a Gaussian distribution, which introduces bias into the reverse process.
To overcome these two problems, the idea of flow matching is to directly interpolate between the data density \(p_1(x)\) and a Gaussian (or other simple) density \(p_0(x)\).
So how do we interpolate between \(p_1(x)\) and \(p_0(x)\)? From the previous lecture, we know that a density \(p_t(x)\) can be evolved by following a vector field \(u_t(x)\): samples \(X_t \sim p_t(x)\) move along the ODE
\[
\frac{d X_t}{d t} = u_t(X_t),
\]
or, equivalently, \(p_t\) satisfies the continuity equation \(\partial_t p_t(x) + \nabla \cdot \bigl(p_t(x)\, u_t(x)\bigr) = 0\).
The interpolation between \(X_0 \sim p_0(x)\) and \(X_1 \sim p_1(x)\) can be achieved in many ways, but the simplest is a linear interpolation:
\[
X_t = (1 - t)\, X_0 + t\, X_1, \qquad t \in [0, 1].
\]
We can learn the vector field \(u_t(x)\) with a neural network \(u^{\theta}_t(x)\) (or other methods) by minimizing the following flow matching loss:
\[
\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t,\, X_t \sim p_t}\left[ \bigl\| u^{\theta}_t(X_t) - u_t(X_t) \bigr\|^2 \right].
\]
However, the marginal vector field \(u_t(x)\) is generally intractable, so the loss above cannot be computed directly.
## Conditional Flow Matching
To overcome this problem, we now consider a simpler case where the target density \(p_1(x)\) is a singleton point \(x_1\). The interpolation between \(X_0 \sim p_0(x) = N(0, I)\) and \(X_1 = x_1\) becomes:
\[
X_{t|1} = (1 - t)\, X_0 + t\, x_1.
\]
As we have
\[
\frac{d}{dt} X_{t|1} = x_1 - X_0 \qquad \text{and} \qquad X_0 = \frac{X_{t|1} - t\, x_1}{1 - t},
\]
we obtain the conditional vector field
\[
u_t(x | x_1) = \frac{x_1 - x}{1 - t}.
\]
It can be shown that if \(p_0(x)\) evolves along the vector field \(u_{t}(x | x_1)\), every sample converges to \(x_1\) at \(t=1\).
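As a quick check of this claim, we can verify directly that the linear path above solves this ODE:
\[
\frac{d}{dt}\Bigl[(1-t)\, X_0 + t\, x_1\Bigr] = x_1 - X_0 = \frac{x_1 - \bigl[(1-t)\, X_0 + t\, x_1\bigr]}{1-t} = u_t\bigl(X_{t|1} | x_1\bigr),
\]
so \(X_{t|1} = (1-t)\, X_0 + t\, x_1\) regardless of the starting point \(X_0\), and \(X_{1|1} = x_1\).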
We can then treat the above example as the conditional case for a general target density \(X_1 \sim p_1(x)\). Each \(X_{t|1}\) is a conditional path, and we mix all conditional paths \(p_t(x|x_1)\) to get the marginal density \(p_t(x)\):
\[
p_t(x) = \int p_t(x | x_1)\, p_1(x_1)\, dx_1.
\]
We consider the conditional flow matching (CFM) loss:
\[
\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, X_1 \sim p_1,\, X_t \sim p_t(\cdot | X_1)}\left[ \bigl\| u^{\theta}_t(X_t) - u_t(X_t | X_1) \bigr\|^2 \right].
\]
It can be shown that \(\mathcal{L}_{\text{CFM}}\) is the same as the marginal flow matching loss \(\mathcal{L}_{\text{FM}}\) up to an additive constant that does not depend on \(\theta\):
\[
\mathcal{L}_{\text{FM}}(\theta) = \mathcal{L}_{\text{CFM}}(\theta) + C.
\]
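A brief sketch of why this holds (a step not spelled out above): the marginal vector field is the conditional expectation of the conditional vector fields,
\[
u_t(x) = \mathbb{E}\bigl[\, u_t(X_t | X_1) \;\big|\; X_t = x \,\bigr] = \int u_t(x | x_1)\, \frac{p_t(x | x_1)\, p_1(x_1)}{p_t(x)}\, dx_1,
\]
so when both squared losses are expanded, the \(\theta\)-dependent terms \(\mathbb{E}\bigl\|u^{\theta}_t(X_t)\bigr\|^2 - 2\,\mathbb{E}\bigl\langle u^{\theta}_t(X_t),\, u_t(X_t)\bigr\rangle\) agree by the tower property, while the remaining terms do not depend on \(\theta\). In particular, \(\nabla_\theta \mathcal{L}_{\text{FM}} = \nabla_\theta \mathcal{L}_{\text{CFM}}\).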
Plugging in the conditional vector field evaluated along the path, \(u_t(X_t|X_1) = \frac{X_1 - X_t}{1 - t} = X_1 - X_0\) for \(X_t = (1-t) X_0 + t X_1\), the CFM loss becomes:
\[
\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, X_0 \sim N(0, I),\, X_1 \sim p_1}\left[ \bigl\| u^{\theta}_t(X_t) - (X_1 - X_0) \bigr\|^2 \right].
\]
## Code Implementation
We can implement CFM in PyTorch by defining a module that contains both the vector field predictor and a sampling (Euler integration) step.
```python
import torch
from torch import nn, Tensor


class Flow(nn.Module):
    def __init__(self, dim: int = 2, h: int = 64):
        super().__init__()
        # Small MLP that takes (t, x_t) and predicts the vector field u_t(x_t).
        self.net = nn.Sequential(
            nn.Linear(dim + 1, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim))

    def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
        # t: (batch, 1), x_t: (batch, dim) -> predicted vector field of shape (batch, dim)
        return self.net(torch.cat((t, x_t), -1))

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
        # One Euler step of the ODE dx/dt = u_t(x) from t_start to t_end.
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        return x_t + (t_end - t_start) * self(t=t_start, x_t=x_t)
```
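As a quick sanity check on the expected tensor shapes (this snippet is only illustrative, not part of the original code):

```python
flow = Flow(dim=2, h=64)
t = torch.rand(5, 1)        # times, shape (batch, 1)
x_t = torch.randn(5, 2)     # positions, shape (batch, dim)
v = flow(t=t, x_t=x_t)      # predicted vector field, shape (batch, dim)
print(v.shape)              # torch.Size([5, 2])
```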
Then we can train the model by minimizing the CFM loss.
```python
from sklearn.datasets import make_moons

flow = Flow()
optimizer = torch.optim.Adam(flow.parameters(), 1e-2)
loss_fn = nn.MSELoss()

for _ in range(10000):
    x_1 = Tensor(make_moons(256, noise=0.05)[0])   # data samples X_1 ~ p_1
    x_0 = torch.randn_like(x_1)                    # noise samples X_0 ~ N(0, I)
    t = torch.rand(len(x_1), 1)
    x_t = (1 - t) * x_0 + t * x_1                  # linear interpolation X_t
    dx_t = x_1 - x_0                               # conditional vector field target

    optimizer.zero_grad()
    loss_fn(flow(t=t, x_t=x_t), dx_t).backward()
    optimizer.step()
```
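Once trained, we can draw new samples by starting from Gaussian noise and integrating the learned vector field from \(t = 0\) to \(t = 1\) with the `step` method. A minimal sketch (the number of Euler steps below is an arbitrary choice):

```python
n_steps = 8
x = torch.randn(300, 2)                               # start from X_0 ~ N(0, I)
time_grid = torch.linspace(0.0, 1.0, n_steps + 1)

with torch.no_grad():
    for i in range(n_steps):
        x = flow.step(x_t=x, t_start=time_grid[i], t_end=time_grid[i + 1])

# x now holds approximate samples from the moons distribution p_1.
```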
The animation below compares flow matching and DDPM on the moons dataset. You can see that flow matching is more stable and efficient.
*(Animation: side-by-side sampling trajectories, flow matching on the left and DDPM on the right.)*
## General Case
The linear interpolation is, of course, only the simplest case. Summarizing the above, the general strategy is: first construct an interpolation between \(p_0(x)\) and a singleton point \(x_1\), then find the corresponding conditional vector field \(u_t(x|x_1)\).
Following this, we have the general form of the density \(p_t\) interpolating between \(N(0, I)\) and \(x_1\). Suppose we want a conditional vector field that generates a path of Gaussians, i.e.,
\[
p_t(x | x_1) = N\bigl(x \mid \mu_t(x_1),\, \sigma_t(x_1)^2 I\bigr),
\]
where \(\mu_0(x_1) = 0\), \(\mu_1(x_1) = x_1\) and \(\sigma_0(x_1) = 1\), \(\sigma_1(x_1) = \sigma_{\min}\). Here we choose a sufficiently small \(\sigma_{\min}\) and use \(N(x_1, \sigma_{\min}^2 I)\) to approximate the singleton point \(x_1\).
Namely, we have
\[
X_{t|1} = \sigma_t(x_1)\, X_0 + \mu_t(x_1), \qquad X_0 \sim N(0, I).
\]
So we have the ordinary differential equation:
\[
\frac{d}{dt} X_{t|1} = \sigma_t'(x_1)\, X_0 + \mu_t'(x_1).
\]
Substituting \(X_0 = \frac{X_{t|1} - \mu_t(x_1)}{\sigma_t(x_1)}\), this implies that the conditional vector field is:
\[
u_t(x | x_1) = \frac{\sigma_t'(x_1)}{\sigma_t(x_1)} \bigl(x - \mu_t(x_1)\bigr) + \mu_t'(x_1).
\]
We then have the general conditional flow matching loss:
\[
\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, X_0 \sim N(0, I),\, X_1 \sim p_1}\left[ \bigl\| u^{\theta}_t(X_t) - \bigl(\sigma_t'(X_1)\, X_0 + \mu_t'(X_1)\bigr) \bigr\|^2 \right],
\]
where \(X_t = \sigma_t(X_1) X_0 + \mu_t(X_1)\).
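As an illustration, the general loss can be trained with the same loop as in the code section: only the construction of \(X_t\) and the regression target change. A minimal sketch, where `cfm_target`, `mu_t`, `dmu_t`, `sigma_t`, and `dsigma_t` are illustrative names (not from the original) for whichever path is chosen:

```python
import torch

def cfm_target(x_0, x_1, t, mu_t, dmu_t, sigma_t, dsigma_t):
    """Return (X_t, d/dt X_t) for the Gaussian path N(mu_t(x_1), sigma_t(x_1)^2 I)."""
    x_t = sigma_t(t, x_1) * x_0 + mu_t(t, x_1)      # X_t = sigma_t(X_1) X_0 + mu_t(X_1)
    dx_t = dsigma_t(t, x_1) * x_0 + dmu_t(t, x_1)   # regression target for u_theta
    return x_t, dx_t

x_0 = torch.randn(256, 2)   # noise batch X_0 ~ N(0, I)
x_1 = torch.randn(256, 2)   # stand-in for a data batch X_1 ~ p_1
t = torch.rand(256, 1)

# The simple linear path mu_t = t x_1, sigma_t = 1 - t recovers the earlier target x_1 - x_0.
x_t, dx_t = cfm_target(
    x_0, x_1, t,
    mu_t=lambda t, x1: t * x1,    dmu_t=lambda t, x1: x1,
    sigma_t=lambda t, x1: 1 - t,  dsigma_t=lambda t, x1: -torch.ones_like(t),
)
```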
### Example: Optimal Transport conditional vector field
The aforementioned linear interpolation corresponds to the choice (keeping the small terminal width \(\sigma_{\min}\)):
\[
\mu_t(x_1) = t\, x_1, \qquad \sigma_t(x_1) = 1 - (1 - \sigma_{\min})\, t.
\]
Thus the conditional vector field is:
\[
u_t(x | x_1) = \frac{x_1 - (1 - \sigma_{\min})\, x}{1 - (1 - \sigma_{\min})\, t}.
\]
The CFM loss becomes:
\[
\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, X_0 \sim N(0, I),\, X_1 \sim p_1}\left[ \bigl\| u^{\theta}_t(X_t) - \bigl(X_1 - (1 - \sigma_{\min})\, X_0\bigr) \bigr\|^2 \right],
\]
where \(X_t = \bigl(1 - (1 - \sigma_{\min})\, t\bigr) X_0 + t\, X_1\).
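In the training loop from the code section, only the construction of `x_t` and `dx_t` changes for this path; a sketch with an illustrative value of \(\sigma_{\min}\):

```python
# x_0, x_1, t as in the training loop above
sigma_min = 1e-4                                    # illustrative small terminal width
x_t = (1 - (1 - sigma_min) * t) * x_0 + t * x_1     # X_t along the OT path
dx_t = x_1 - (1 - sigma_min) * x_0                  # conditional vector field along the path
```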
### Example: Diffusion Conditional Vector Field
Recalling the forward process of the diffusion model, we have a Gaussian path of the form
\[
p_t(x | x_1) = N\bigl(x \mid \alpha_{1-t}\, x_1,\; (1 - \alpha_{1-t}^2)\, I\bigr),
\qquad \text{i.e.} \quad \mu_t(x_1) = \alpha_{1-t}\, x_1, \quad \sigma_t(x_1) = \sqrt{1 - \alpha_{1-t}^2}.
\]
Note that we flipped the notation compared to the previous lecture: \(X_0\) is now the simple Gaussian distribution \(N(0, I)\) and \(X_1\) is the data distribution, so we reverse time in the noise weight and use \(\alpha_{1-t}\).
Then we have the conditional vector field:
\[
u_t(x | x_1) = \frac{\alpha_{1-t}'}{1 - \alpha_{1-t}^2}\,\bigl(\alpha_{1-t}\, x - x_1\bigr),
\]
where \(\alpha_s'\) denotes the derivative of \(\alpha_s\) with respect to \(s\).
We usually choose \(\alpha_t = \exp(-\frac{1}{2}T(t))\), where \(T(t) = \int_0^t \beta(s) ds\). Such a choice corresponds to the variance-preserving diffusion process (see the paper for more details):
\[
dX_t = -\tfrac{1}{2} \beta(t)\, X_t\, dt + \sqrt{\beta(t)}\, dW_t.
\]
Therefore, we have the conditional vector field:
\[
u_t(x | x_1) = -\frac{T'(1-t)}{2}\left[\frac{e^{-T(1-t)}\, x - e^{-\frac{1}{2} T(1-t)}\, x_1}{1 - e^{-T(1-t)}}\right].
\]
Mathematically, this is equivalent to DDPM; however, the flow matching formulation has been shown to be more stable and efficient in practice.
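For concreteness, here is a sketch of this diffusion conditional vector field with a linear noise schedule \(\beta(s)\) (the schedule endpoints and all names are illustrative assumptions, not from the original):

```python
import torch

beta_min, beta_max = 0.1, 20.0                     # assumed linear schedule beta(s)

def T(t):                                          # T(t) = int_0^t beta(s) ds
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2

def alpha(t):                                      # alpha_t = exp(-T(t) / 2)
    return torch.exp(-0.5 * T(t))

def d_alpha(t):                                    # alpha'_t = -(beta(t) / 2) * alpha_t
    beta = beta_min + (beta_max - beta_min) * t
    return -0.5 * beta * alpha(t)

def diffusion_cond_vf(x, x_1, t):
    """Conditional vector field for the path N(alpha_{1-t} x_1, (1 - alpha_{1-t}^2) I).

    Valid for t in [0, 1); at t = 1 the path collapses to x_1 and the field diverges.
    """
    a, da = alpha(1 - t), d_alpha(1 - t)
    return da / (1 - a ** 2) * (a * x - x_1)
```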