Hello! This self-contained notebook will build up, from scratch, a skeleton of contemporary continuous-variable generative models.
You are certain to have interacted, in some form or another, with artificially generated content in your daily life. It may be something you made yourself for creative purposes, like the growing flower depicted in the video above courtesy of OpenAI’s Sora model, or something you came across online (for better or for worse). Below we take a pedagogical tour through some of the mathematics behind these models and walk through a step-by-step procedure for building them ourselves.
Agenda¶
- We’ll start by stating what the learning problem is -- finding a map between two probability densities -- and describe some ways people have thought about constructing it in the past.
- Because we will suggest a continuous-time mapping between the two distributions, we’ll lay out some of the mathematics underpinning this, separate from any generative models or neural networks. This will give us an intuition about ODEs and SDEs.
- Then we’ll explain how these ideas can be used to build generative models. From there, we’ll build a toy generative model to complete the picture, and then learn about the more expressive neural networks used in the large-scale image generation technologies built on these methods.
Dynamical Transport as a generative model¶
Let’s state plainly what our goals are: given some data that are sampled according to some unknown probability density $\rho_1$, learn a model for this distribution that allows us to draw new samples from it. A salient way to do this is to learn a map that connects points $x_0$ from some simple distribution with density $\rho_0$ to points $x_1 \sim \rho_1$. The characteristics of this map will thereby also tell us how $\rho_0$ is adapted into some $\rho_1$.
The predominant way to formulate this problem in recent literature is to think about a continuous adaptation of $\rho_0$ into $\rho_1$, which we will call $\rho_t$, indexed by a time variable $t \in [0, 1]$. The paradigmatic method for doing this is score-based/denoising diffusion (Ho et al. (2020), Song et al. (2020)), in which independent Gaussian noise is iteratively denoised into an image. On a per-sample basis, generating an image of a flower, say, might look like what is depicted in Figure 2.
In what follows we’ll analyze the equations that the density $\rho_t$ and the sample $x_t$ must obey to properly perform this generation. Along the way, we’ll derive some algorithms for learning models of the coefficients in the resulting equations that will give us our generative model.
Transporting densities and their samples: Continuity equation and characteristic curves¶
However we adapt $\rho_0$ into $\rho_1$, we want to ensure that $\rho_t$ remains a valid probability density, meaning that the evolution of $\rho_t$ conserves probability mass. Let’s think about how that might come about by imagining probability mass as a fluid, in a simple 3-dimensional example. Assume we have an incremental volume element $\Delta x\, \Delta y\, \Delta z$ of 3D space as depicted in Figure 3. Given that $\rho_t$ is the density of probability mass at time $t$ and at point $(x, y, z)$, we can label the direction and magnitude of motion of that mass with a velocity field $b$.
Let’s consider motion only in the $x$ direction as per the figure. By dimensional analysis we can conclude that the incremental flux of probability entering the incremental volume element through the $y$-$z$ plane is given by $\rho_t\, b^x\, \Delta y\, \Delta z$, where the superscript $x$ is used to label the velocity in the $x$ direction. This flux may change at an incremental element farther along in space, which we can estimate through Taylor expansion as:

$$\rho_t b^x\big|_{x + \Delta x}\, \Delta y\, \Delta z \approx \Big(\rho_t b^x\big|_{x} + \partial_x\big(\rho_t b^x\big)\, \Delta x\Big)\, \Delta y\, \Delta z,$$
where we only keep the first-order term. Since the term $\partial_x(\rho_t b^x)\, \Delta x\, \Delta y\, \Delta z$ on the RHS is the net flux out of the incremental volume, we can conclude that the change in probability mass at any time due to flow in the $x$ direction reads as $\partial_t \rho_t = -\partial_x(\rho_t b^x)$. The same holds for flux in the $y$ and $z$ directions. Writing this out explicitly as the time derivative of the probability density, we have

$$\partial_t \rho_t(x) + \nabla \cdot \big(b(x, t)\, \rho_t(x)\big) = 0, \qquad \rho_{t=0} = \rho_0. \tag{2}$$
This is the continuity equation for the density $\rho_t$ and the velocity field $b$, where we have included the initial condition that at time $t = 0$ we start from our base distribution $\rho_0$. This equation is essentially a statement of probability mass conservation. If we want to find a map between $\rho_0$ and $\rho_1$, then the time-dependent density arising from our map must solve (2). How do we find such a $b$ and $\rho_t$ to perform the transport? Eyeballing (2), things don’t look great, as this is a partial differential equation (PDE). Thankfully, there is a handy trick that goes by the method of characteristics which will allow us to instead solve a family of ordinary differential equations (ODEs) that are much easier to work with. We will apply it to state the following proposition, borrowing from how it is stated in Albergo & Vanden-Eijnden (2024):
Here, we introduced the probability flow ODE

$$\frac{d}{dt} X_t = b(X_t, t), \qquad X_{t=0} = x_0 \sim \rho_0,$$

whose solutions give us characteristic curves along which we can solve (2) and produce samples $X_t \sim \rho_t$. In particular, solving the probability flow ODE up to time $t = 1$ will give us our map to draw samples from $\rho_1$. The proof, which elucidates this point, is given in the dropdown bar below.
Proof 1
We want to solve the PDE (2),

$$\partial_t \rho_t(x) + \nabla \cdot \big(b(x, t)\, \rho_t(x)\big) = 0,$$

for $t \in [0, 1]$. To do this, we want to find characteristic curves $X_t$, with $X_{t=0} = x_0$, along which the total derivative of $\rho_t(X_t)$ with respect to $t$ can be computed with known quantities and no other derivatives, thereby reducing the PDE, which we don’t know how to solve, into an ODE, which we do!
Taking the total derivative of $\rho_t(X_t)$, we have, using the chain rule,

$$\frac{d}{dt}\rho_t(X_t) = \partial_t \rho_t(X_t) + \dot X_t \cdot \nabla \rho_t(X_t).$$

Next, note that we can expand the divergence term in (2) so that

$$\partial_t \rho_t = -\nabla \cdot \big(b\, \rho_t\big) = -\,b \cdot \nabla \rho_t - \rho_t\, \nabla \cdot b.$$

The terms involving $\nabla \rho_t$ are the ones we’d like to eliminate, since we don’t know them. Where does the probability flow ODE come in? If we deliberately choose $\dot X_t = b(X_t, t)$ with appropriate initial condition $X_{t=0} = x_0$, then we have that $\dot X_t \cdot \nabla \rho_t(X_t) = b(X_t, t) \cdot \nabla \rho_t(X_t)$. Using this fact and the expansion above, the gradient terms cancel and we can write

$$\frac{d}{dt}\rho_t(X_t) = -\,b(X_t, t) \cdot \nabla \rho_t(X_t) - \rho_t(X_t)\, \nabla \cdot b(X_t, t) + b(X_t, t) \cdot \nabla \rho_t(X_t),$$

so that

$$\frac{d}{dt}\rho_t(X_t) = -\,\rho_t(X_t)\, \nabla \cdot b(X_t, t).$$
This is an ODE describing the evolution of $\rho_t(X_t)$ along the curves $X_t$, which can be solved straightforwardly as follows. We can treat $\rho_t(X_t)$ as a function of time only, because along the characteristic curve $x_0$ is fixed as an initial condition. Then we have

$$\frac{d}{dt}\rho_t(X_t) = -\,\rho_t(X_t)\, \nabla \cdot b(X_t, t).$$

Dividing by $\rho_t(X_t)$ and integrating both sides gives

$$\log \rho_t(X_t) - \log \rho_0(x_0) = -\int_0^t \nabla \cdot b(X_s, s)\, ds,$$

where we can conclude

$$\rho_t(X_t) = \rho_0(x_0)\, \exp\!\Big(-\int_0^t \nabla \cdot b(X_s, s)\, ds\Big).$$
Below, we visualize in Figure 4 solutions to the probability flow ODE to map out the characteristic curves as $t$ goes from $0$ to $1$ for various initial conditions.
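To make the characteristic formula concrete, here is a minimal numerical sketch that is not part of the pipeline we build later: we hand-pick the toy velocity field $b(x, t) = a x$ in one dimension (an illustrative choice of ours), integrate the probability flow ODE with Euler steps, and compare the density given by the characteristic formula to the exact density, which for this scaling flow is $\mathcal{N}(0, e^{2 a t})$.

import numpy as np

# Toy check: b(x, t) = a * x is a pure scaling flow. Starting from rho_0 = N(0, 1),
# the exact density at time t is N(0, exp(2 * a * t)).
a = 0.5
n_step = 1000
dt = 1.0 / n_step

x0 = 0.7                                          # initial condition X_{t=0} = x0
rho0 = np.exp(-x0**2 / 2) / np.sqrt(2 * np.pi)    # rho_0(x0)

x, div_integral = x0, 0.0
for _ in range(n_step):
    div_integral += a * dt                        # divergence of b(x, t) = a * x is just a
    x = x + a * x * dt                            # Euler step of the probability flow ODE

rho1_characteristics = rho0 * np.exp(-div_integral)   # characteristic formula at t = 1
var1 = np.exp(2 * a)                                  # exact variance of rho_1
rho1_exact = np.exp(-x**2 / (2 * var1)) / np.sqrt(2 * np.pi * var1)

print(rho1_characteristics, rho1_exact)               # agree up to O(dt) discretization error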
For a fixed $\rho_t$, we may not know what this $b$ is that we need in order to solve the probability flow ODE. In fact, we don’t know it for just about all non-trivial transport problems (except for e.g. Gaussian mixtures)! How can we go about discovering one? Below we’ll implement a learning algorithm for doing so, known as stochastic interpolants or flow matching (Albergo & Vanden-Eijnden (2022), Lipman et al. (2022), Liu et al. (2022)).
Stochastic Interpolants / Flow Matching¶
As we have said above, our learning problem is: given samples $x_1 \sim \rho_1$ and a base distribution with density $\rho_0$, learn a model to generate new samples from $\rho_1$. We just spent a lot of time analyzing solutions to the continuity equation involving the time-dependent density $\rho_t$ and the velocity field $b$ that endow a flow map. To learn a map, we need to find a $b$ which solves the transport. But there’s also another problem: many such maps exist which meet the criterion that $\rho_{t=1} = \rho_1$, depending on the $\rho_t$ you choose, and we haven’t discussed how to choose a $\rho_t$!
What you’ll define below is a way to do this using the method of stochastic interpolants, which gives us a way to simultaneously choose a time-dependent density $\rho_t$ and learn the $b$, over a parametric function class (neural networks), that performs the associated mapping.
Let’s lay out the steps to do this in code. Here’s what we’ll need:
- a python package that allows us to perform optimization over our parametric functions using autodifferentiation: PyTorch
- A way to define and sample under the probability densities $\rho_0$ and $\rho_1$
- A neural network that we will use as our model of the velocity field
- An implementation of the stochastic interpolant and the associated optimization loop used in conjunction with it to fit $\hat{b}$ to the true $b$.
- Some visualizations of the results!
We will do this on a toy dataset that defines a distribution in two dimensions. Let’s begin!
Software pre-requisites¶
try:
    import numpy as np
    import matplotlib.pyplot as plt
    import torch  # 2.0 or greater
    from torch.func import vmap
    print("All packages are installed!")
except ImportError as e:
    print(f"An error occurred: {e}. Please make sure all required packages are installed.")

from typing import Callable
import math
All packages are installed!
def grab(var):
    return var.detach().cpu().numpy()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(torch.cuda.is_available())
print(device)
True
cuda
Defining and sampling our two distributions¶
To learn a generative model to sample under $\rho_1$, we first need to define what distribution $\rho_1$ we want to sample from and what distribution $\rho_0$ we start from!
Let’s say that our data is in $\mathbb{R}^2$. We will take $\rho_0$ as a Normal distribution with mean vector $\mu = (0, 0)$ and identity covariance, so that $\rho_0 = \mathcal{N}(0, I_2)$.
import torch
from torch.distributions import MultivariateNormal

class BaseDistribution:
    def __init__(self, mean, cov):
        self.mean = mean
        self.covariance = cov
        self.distribution = MultivariateNormal(mean, cov)

    def sample(self, n=1):
        """
        Draws $n$ samples from the Gaussian distribution.
        """
        return self.distribution.sample((n,))

    def log_prob(self, x):
        """
        Evaluates the log probability of given samples $x$ under the distribution.
        """
        return self.distribution.log_prob(x)

mean = torch.tensor([0.0, 0.0]).to(device)                # \mu \in R^2
cov = torch.tensor([[1.0, 0.0], [0.0, 1.0]]).to(device)   # \Sigma \in R^{2x2}
base = BaseDistribution(mean, cov)
Let’s choose our target distribution $\rho_1$ to be either a checkerboard in 2D space or a Gaussian mixture, defined as follows:¶
## checker
# ndim = 2
# def target(bs):
# x1 = torch.rand(bs) * 4 - 2
# x2_ = torch.rand(bs) - torch.randint(2, (bs,)) * 2
# x2 = x2_ + (torch.floor(x1) % 2)
# return (torch.cat([x1[:, None], x2[:, None]], 1) * 2)
from torch.distributions.mixture_same_family import MixtureSameFamily
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from torch.distributions.independent import Independent
from torch.distributions.multivariate_normal import MultivariateNormal

class Prior(torch.nn.Module):
    """
    Abstract class for prior distributions of normalizing flows. The interface
    is similar to `torch.distributions.distribution.Distribution`, but batching
    is treated differently. Parameters passed to constructors are never batched,
    but are aware of the target (single) sample shape. The `forward` method then
    accepts just the batch size and produces a batch of samples of the known
    shape.
    """
    def forward(self, batch_size):
        raise NotImplementedError()

    def log_prob(self, x):
        raise NotImplementedError()

    def draw(self, batch_size):
        """Alias of `forward` to allow explicit calls."""
        return self.forward(batch_size)
class GMM(Prior):
    def __init__(self, loc=None, var=None, scale=1.0, ndim=None, nmix=None, device=device, requires_grad=False):
        super().__init__()
        self.device = device
        self.scale = scale  ### only specify if loc is None

        def _compute_mu(ndim):
            return self.scale*torch.randn((1, ndim))

        if loc is None:
            self.nmix = nmix
            self.ndim = ndim
            loc = torch.cat([_compute_mu(ndim) for i in range(1, self.nmix + 1)], dim=0)
            var = torch.stack([1.0*torch.ones((ndim,)) for i in range(nmix)])
        else:
            self.nmix = loc.shape[0]
            self.ndim = loc.shape[1]  ### locs should have shape [n_mix, ndim]

        self.loc = loc  ### locs should have shape [n_mix, ndim]
        self.var = var  ### should have shape [n_mix, ndim]
        if requires_grad:
            self.loc.requires_grad_()
            self.var.requires_grad_()

        mix = Categorical(torch.ones(self.nmix,).to(device))
        comp = Independent(Normal(self.loc, self.var), 1)
        self.dist = MixtureSameFamily(mix, comp)

    def log_prob(self, x):
        logp = self.dist.log_prob(x)
        return logp

    def forward(self, batch_size):
        x = self.dist.sample((batch_size,))
        return x

    def rsample(self, batch_size):
        x = self.dist.rsample((batch_size,))
        return x
nmix = 8
ndim = 2

def _compute_mu(i):
    return 5.0 * torch.Tensor([[
        torch.tensor(i * math.pi / 4).sin(),
        torch.tensor(i * math.pi / 4).cos()]]).to(device)

mus_target = torch.stack([_compute_mu(i) for i in range(nmix)]).squeeze(1).to(device)
var_target = torch.stack([torch.tensor([0.7, 0.7]) for i in range(nmix)]).to(device)

target = GMM(mus_target, var_target)
target_samples = target(10000)
Below we can visualize both of the distributions by drawing samples from them and plotting them:
bs = 5000
c = '#62508f' # plot color
x0s = grab(base.sample(bs))
x1s = grab(target(bs))
fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(10,5))
ax0.scatter(x0s[:,0], x0s[:,1], alpha = 0.5, c = c);
ax0.set_xlim(-4,4), ax0.set_ylim(-4,4)
ax0.set_title(r"Samples under $\rho_0$", fontsize = 16)
ax0.set_xticks([-4,0,4]), ax0.set_yticks([-4,0,4])
ax1.scatter(x1s[:,0], x1s[:,1], alpha = 0.5, c = c);
ax1.set_xlim(-8,8), ax1.set_ylim(-8,8)
ax1.set_title(r"Samples under $\rho_1$", fontsize = 16)
ax1.set_xticks([-4,0,4]), ax1.set_yticks([-4,0,4]);
plt.show();
Building a map between $\rho_0$ and $\rho_1$ with the interpolant¶
You have been learning in class that you can use the method of stochastic interpolants to define a time-dependent density $\rho_t$ that connects $\rho_0$ to $\rho_1$ and allows us to learn the associated velocity field $b$ that would map samples from one to samples from the other.
Let’s introduce the interpolant function here and make some statements about it. The stochastic interpolant is a stochastic process given as:

$$x_t = \alpha(t)\, x_0 + \beta(t)\, x_1, \qquad x_0 \sim \rho_0, \quad x_1 \sim \rho_1,$$

where $\alpha(0) = \beta(1) = 1$ and $\alpha(1) = \beta(0) = 0$.
- When $x_0 \sim \rho_0$ and $x_1 \sim \rho_1$ are drawn accordingly, we say that the stochastic interpolant $x_t$ is a sample under the time-dependent density $\rho_t$, i.e. $x_t \sim \rho_t$.
- The velocity field associated to this $\rho_t$ is given by the conditional expectation of the time dynamics of the interpolant, namely:

$$b(x, t) = \mathbb{E}\big[\dot x_t \,\big|\, x_t = x\big] = \mathbb{E}\big[\dot\alpha(t)\, x_0 + \dot\beta(t)\, x_1 \,\big|\, x_t = x\big].$$
What do we mean when we say samples $x_t \sim \rho_t$? Well, let’s take a look by implementing it. We will make a class called Interpolant that implements $x_t$ and its time derivative $\dot x_t$.
## using alpha(t) = (1-t) and beta(t) = t
class Interpolant:
    def alpha(self, t):
        return 1.0 - t

    def dotalpha(self, t):
        return -1.0 + 0*t

    def beta(self, t):
        return t

    def dotbeta(self, t):
        return 1.0 + 0*t

    def _single_xt(self, x0, x1, t):
        return self.alpha(t)*x0 + self.beta(t)*x1

    def _single_dtxt(self, x0, x1, t):
        return self.dotalpha(t)*x0 + self.dotbeta(t)*x1

    def xt(self, x0, x1, t):
        return vmap(self._single_xt, in_dims=(0, 0, 0))(x0, x1, t)

    def dtxt(self, x0, x1, t):
        return vmap(self._single_dtxt, in_dims=(0, 0, 0))(x0, x1, t)

interpolant = Interpolant()
bs = 10000
x0s = base.sample(bs).to(device)
x1s = target(bs).to(device)
t = 0.2*torch.ones(bs).to(device)
xts = interpolant.xt(x0s, x1s, t)
In the above code, we implement the interpolant for a single sample of $x_0$, $x_1$, and $t$, which have shape [d], [d], [1] respectively. Then, we use a tool in PyTorch called vmap, which allows us to generalize the code to arbitrary batches of samples of $x_0$ and $x_1$. That way, if I have $N$ samples of $x_0$, $x_1$, and $t$, I can compute the interpolant for all of them at once using vmap. This means I can feed the function interpolant.xt a batch of samples of shape [N, d] for $x_0$ and $x_1$ and of shape [N] for time.
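As a quick sanity check on those shape claims (reusing only the batch created above, nothing new is assumed), we can print the shapes that go into and come out of interpolant.xt:

# Inputs of shape [N, d], [N, d], [N] should produce an output of shape [N, d];
# with the batch above we expect torch.Size([10000, 2]) for xts.
print(x0s.shape, x1s.shape, t.shape, xts.shape)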
Let’s see what happens when I sample $x_t$ at various times $t$ along $[0, 1]$:
bs = 8000
ncol = 6
ts = torch.linspace(0, 1, ncol)
c = '#62508f'

fig, axes = plt.subplots(1, ncol, figsize=(ncol*4, 4))
for i, t in enumerate(ts):
    tt = t.repeat(bs).to(device)
    x0s = base.sample(bs).to(device)
    x1s = target(bs).to(device)
    xts = interpolant.xt(x0s, x1s, tt)
    axes[i].scatter(grab(xts[:,0]), grab(xts[:,1]), alpha = 0.08, c = c); # plot samples x_t \sim \rho_t
    axes[i].set_xticks([])
    axes[i].set_title(r'$\rho(t = %.1f)$' % t, fontsize = 20, weight='bold')
    axes[i].set_xlim(-8,8)
    axes[i].set_ylim(-8,8)
    if i != 0:
        axes[i].set_yticks([])
plt.show();
Neural Network to model $b$¶
Now we need to define a neural network $\hat b$ which has learnable parameters to model $b$. Recall that the velocity field takes in a time coordinate $t \in [0, 1]$ and a spatial sample $x$ which is in $\mathbb{R}^d$. This means that our neural network approximation should be a function that has an input size of $d + 1$ and should have an output dimension of $d$.
To define a neural network, we’ll use PyTorch’s torch.nn.Module class, which allows us to compose the parts of the neural network in a way that lets us take derivatives with respect to the network’s weights:
from torch.func import jacrev, grad

class VelocityField(torch.nn.Module):
    # a neural network that takes x in R^d and t in [0, 1] and outputs a value in R^d
    def __init__(self, d, hidden_sizes = [256, 256], activation=torch.nn.ReLU):
        super(VelocityField, self).__init__()

        layers = []
        prev_dim = d + 1  # input is the concatenation of x and t
        for hidden_size in hidden_sizes:
            layers.append(torch.nn.Linear(prev_dim, hidden_size))
            layers.append(activation())
            prev_dim = hidden_size  # update prev_dim for the next layer

        # final layer
        layers.append(torch.nn.Linear(prev_dim, d))

        # Wrap all layers in a Sequential module
        self.net = torch.nn.Sequential(*layers)

    def _single_forward(self, x, t):
        t = t.unsqueeze(-1)
        return self.net(torch.cat((x, t)))

    def forward(self, x, t):
        return vmap(self._single_forward, in_dims=(0,0), out_dims=(0))(x, t)
Let’s check to make sure our neural network class works by making some fake data and seeing that it outputs something of the right shape
d = 2
b = VelocityField(d, hidden_sizes=[512, 512, 512]).to(device)
bs = 10 ## simple test batch size
x = torch.rand(bs, d).to(device)
t = torch.rand(bs).to(device)
print(x.shape)
print(t.shape)
out = b.forward(x,t) ## should output something of shape [bs, d]
torch.Size([10, 2])
torch.Size([10])
Loss function and learning¶
Now that we have a suitable neural network to optimize, we need to specify the learning rule to update the weights of $\hat b$.
Given that the velocity field of the interpolant is $b(x, t) = \mathbb{E}[\dot x_t \mid x_t = x]$, it should be the unique minimizer of the least-squares loss given by:

$$\mathcal{L}[\hat b] = \mathbb{E}\Big[\tfrac{1}{2}\big|\hat b(x_t, t)\big|^2 - \dot x_t \cdot \hat b(x_t, t)\Big],$$

where the expectation is taken over $t \sim \mathrm{U}[0, 1]$, $x_0 \sim \rho_0$, and $x_1 \sim \rho_1$.
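As a brief aside (a standard completing-the-square argument, sketched here for completeness rather than quoted from above), we can see why the true $b$ is the unique minimizer. Using the tower property of conditional expectation,

$$\mathbb{E}\Big[\tfrac12 \big|\hat b(x_t, t)\big|^2 - \dot x_t \cdot \hat b(x_t, t)\Big] = \mathbb{E}\Big[\tfrac12 \big|\hat b(x_t, t)\big|^2 - b(x_t, t) \cdot \hat b(x_t, t)\Big] = \mathbb{E}\Big[\tfrac12 \big|\hat b(x_t, t) - b(x_t, t)\big|^2\Big] - \tfrac12\, \mathbb{E}\big[|b(x_t, t)|^2\big],$$

since $b(x_t, t) = \mathbb{E}[\dot x_t \mid x_t]$. The last term does not depend on $\hat b$, so the loss is minimized exactly when $\hat b = b$ on the support of $\rho_t$.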
Let’s go ahead and implement this in code:
def _single_loss(b, interpolant, x0, x1, t):
    """
    Interpolant loss function for a single datapoint of (x0, x1, t).
    """
    It = interpolant._single_xt(x0, x1, t)
    dtIt = interpolant._single_dtxt(x0, x1, t)
    bt = b._single_forward(It, t)
    loss = 0.5*torch.sum(bt**2) - torch.sum((dtIt) * bt)
    return loss

loss_fn = vmap(_single_loss, in_dims=(None, None, 0, 0, 0), out_dims=(0), randomness='different')
N = 10
x0s = base.sample(N).to(device)
x1s = target(N).to(device)
ts = torch.rand(N).to(device)
loss_fn(b, interpolant, x0s, x1s, ts).mean()
tensor(0.2853, device='cuda:0', grad_fn=<MeanBackward0>)
Training step¶
Now that we have constructed our loss, let’s put it in a loop to iteratively update the parameters of $\hat b$ as we move toward the minimizer of $\mathcal{L}$.
To perform this parameter update, we need to introduce a PyTorch optimizer that performs the gradient update for us. We do that via the following, specifying a learning rate lr. We use the Adam optimizer, which is a fancier version of SGD.
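Schematically (and glossing over Adam’s momentum and per-parameter scaling, which are not spelled out here), each optimizer step nudges the weights $\theta$ of $\hat b$ against the gradient of the empirical loss on the sampled batch:

$$\theta \leftarrow \theta - \mathrm{lr} \cdot \nabla_\theta\, \hat{\mathcal{L}}(\theta).$$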
lrb = 2e-3  ## learning rate
opt = torch.optim.Adam([
    {'params': b.parameters(), 'lr': lrb} ])
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1500, gamma=0.6)
Now let’s use it in a function called train_step which will perform our iteration:
def train_step(b, interpolant, opt, sched, N):
    ## draw N samples from each distribution and draw times uniformly from [0, 1]
    x0s = base.sample(N).to(device)
    x1s = target(N).to(device)
    ts = torch.rand(N).to(device)

    # evaluate loss
    loss_val = loss_fn(b, interpolant, x0s, x1s, ts).mean()

    # perform backprop
    loss_val.backward()
    opt.step()
    sched.step()
    opt.zero_grad()

    res = {
        'loss': loss_val.detach(),
    }
    return res
Running the Training¶
We are now ready to train our model. Let’s build a loop that runs for n_opt steps and stores the loss over time.
n_opt = 5000
bs = 2500
losses = []

from tqdm import tqdm
pbar = tqdm(range(n_opt))

for i in pbar:
    res = train_step(b, interpolant, opt, sched, bs)
    loss = grab(res['loss'])
    losses.append(loss)
    pbar.set_description(f'Loss: {loss:.4f}')
Loss: -8.8316: 100%|██████████| 5000/5000 [00:13<00:00, 362.18it/s]
plt.plot(losses);
Implementing the ODE integrator¶
Now that we’ve learned $\hat b$, we can use it in the probability flow ODE to get our generative model. To solve this ODE, we will use the Euler method, and iterate on the procedure:

$$X_{t + \Delta t} = X_t + \Delta t\, \hat b(X_t, t), \qquad X_{t=0} = x_0 \sim \rho_0.$$
We will do this with the following class, called ODE, which implements the iterative step and performs the rollout over the interval $t \in [0, 1]$.
class ODE:
    def __init__(self, b, interpolant, n_step):
        self.b = b
        self.interpolant = interpolant
        self.n_step = n_step
        self.ts = torch.linspace(0.0, 1.0, n_step + 1).to(device)
        self.dt = self.ts[1] - self.ts[0]

    def step(self, x, t):
        return x + self.b(x, t)*self.dt

    def solve(self, x_init):
        bs = x_init.shape[0]
        xs = torch.zeros((self.n_step, *x_init.shape))
        x = x_init
        for i, t in enumerate(self.ts[:-1]):
            t = t.repeat(len(x))
            x = self.step(x, t)
            xs[i] = x
        return xs
ode = ODE(b, interpolant, n_step = 80)
x_init = base.sample(20000).to(device)
xfs = ode.solve(x_init)
x1s = grab(xfs[-1])
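As a quick, optional sanity check (using only variables already defined above), we can compare simple empirical statistics of the ODE-generated samples against the reference samples drawn directly from the target:

# Compare first and second moments of generated vs. true target samples.
gen = torch.tensor(x1s)                 # samples produced by the probability flow ODE
ref = target_samples.detach().cpu()     # samples drawn directly from the GMM target
print("generated mean:", gen.mean(0), "  target mean:", ref.mean(0))
print("generated std: ", gen.std(0),  "  target std: ", ref.std(0))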
Implementing the SDE integrator¶
We can also use the velocity field in the generative model based on the SDE

$$dX_t = \big(b(X_t, t) + \epsilon_t\, s(X_t, t)\big)\, dt + \sqrt{2 \epsilon_t}\; dW_t,$$

where $s(x, t) = \nabla \log \rho_t(x)$ is the score and $\epsilon_t \ge 0$ is the diffusion coefficient. To get $s(x, t)$ we can use its relation to $b(x, t)$:

$$s(x, t) = \frac{\beta(t)\, b(x, t) - \dot\beta(t)\, x}{\alpha(t)\big(\alpha(t)\dot\beta(t) - \dot\alpha(t)\beta(t)\big)}.$$

Since $\alpha(t) \to 0$ as $t \to 1$, to avoid problems near $t = 1$, below we will use $\epsilon_t = \epsilon\, \alpha(t)$ for some constant $\epsilon$, and calculate $\alpha(t)\, s(x, t)$ directly instead of $s(x, t)$.
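Where does this relation come from? A short sketch, under the assumption (true in this notebook) that $\rho_0 = \mathcal{N}(0, I)$: for the interpolant $x_t = \alpha(t) x_0 + \beta(t) x_1$ with a Gaussian base, the score satisfies $s(x, t) = -\mathbb{E}[x_0 \mid x_t = x] / \alpha(t)$. Combining

$$b(x, t) = \dot\alpha(t)\, \mathbb{E}[x_0 \mid x_t = x] + \dot\beta(t)\, \mathbb{E}[x_1 \mid x_t = x], \qquad x = \alpha(t)\, \mathbb{E}[x_0 \mid x_t = x] + \beta(t)\, \mathbb{E}[x_1 \mid x_t = x],$$

and eliminating $\mathbb{E}[x_1 \mid x_t = x]$ gives $\mathbb{E}[x_0 \mid x_t = x] = \dfrac{\beta(t)\, b(x, t) - \dot\beta(t)\, x}{\dot\alpha(t)\beta(t) - \alpha(t)\dot\beta(t)}$, which yields the formula above after dividing by $-\alpha(t)$.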
To solve the SDE, we will use the Euler-Maruyama method, and iterate on the procedure:

$$X_{t + \Delta t} = X_t + \big(b(X_t, t) + \epsilon_t\, s(X_t, t)\big)\, \Delta t + \sqrt{2 \epsilon_t\, \Delta t}\; \xi_t,$$

where $\xi_t \sim \mathcal{N}(0, I_d)$.
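Note how this lines up with the code below (the naming is the notebook’s own): with $\epsilon_t = \epsilon\, \alpha(t)$, the drift correction $\epsilon_t\, s(X_t, t) = \epsilon \cdot \alpha(t)\, s(X_t, t)$ is exactly eps * alphascore(x, t), and the noise amplitude becomes $\sqrt{2 \epsilon\, \Delta t}\, \sqrt{\alpha(t)}$, which is sqrtepsdt * torch.sqrt(alpha).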
We will do this with the following class, called SDE, which implements the iterative step and performs the rollout over the interval $t \in [0, 0.99]$, stopping just short of $t = 1$.
class SDE:
    def __init__(self, b, interpolant, eps, n_step):
        self.b = b
        self.interpolant = interpolant
        self.n_step = n_step
        self.ts = torch.linspace(0.0, 0.99, n_step + 1).to(device)
        self.dt = self.ts[1] - self.ts[0]
        self.eps = eps.to(device)
        self.sqrtepsdt = torch.sqrt(torch.tensor(2.0)*self.eps*self.dt).to(device)

    def alphascore(self, x, t):
        # returns alpha(t) * score(x, t), computed from the learned velocity field
        alpha = self.interpolant.alpha(t).unsqueeze(-1)
        dotalpha = self.interpolant.dotalpha(t).unsqueeze(-1)
        beta = self.interpolant.beta(t).unsqueeze(-1)
        dotbeta = self.interpolant.dotbeta(t).unsqueeze(-1)
        return (beta*self.b(x, t) - dotbeta*x)/(alpha*dotbeta - dotalpha*beta)

    def step(self, x, t):
        alpha = self.interpolant.alpha(t).unsqueeze(-1)
        return x + (self.b(x, t) + self.eps*self.alphascore(x, t))*self.dt + self.sqrtepsdt*torch.sqrt(alpha)*torch.randn(x.size()).to(device)

    def solve(self, x_init):
        bs = x_init.shape[0]
        xs = torch.zeros((self.n_step, *x_init.shape))
        x = x_init
        for i, t in enumerate(self.ts[:-1]):
            t = t.repeat(len(x))
            x = self.step(x, t)
            xs[i] = x
        return xs
sde = SDE(b, interpolant, eps = torch.tensor(2.0), n_step = 80)
# x_init = base.sample(20000)
xf2s = sde.solve(x_init).to(device)
x2s = grab(xf2s[-1])
Now let’s plot the results of our generative models:
c = '#62508f' # plot color
fig, axes = plt.subplots(1,2, figsize=(10,5))
axes[0].scatter(x1s[:,0], x1s[:,1], alpha = 0.03, c = c)
axes[0].set_title(r"Samples from the prob. flow ODE", fontsize = 18)
axes[0].set_xticks([-4,0,4]), axes[0].set_yticks([-4,0,4]);
axes[1].scatter(x2s[:,0], x2s[:,1], alpha = 0.03, c = c)
axes[1].set_title(r"Samples from the SDE", fontsize = 18)
axes[1].set_xticks([-4,0,4]), axes[1].set_yticks([-4,0,4]);
plt.tight_layout();
What’s next?¶
Now that we’ve built out the machinery for learning a dynamical generative model for a simple 2D distribution, we can turn to image generation, which for coding purposes just means replacing our feed-forward neural network with a much better one (that we’ll see on the next page) and loading a different data source. Let’s proceed to that.
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33, 6840–6851.
- Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv Preprint arXiv:2011.13456.
- Albergo, M. S., & Vanden-Eijnden, E. (2024). Learning to sample better. Journal of Statistical Mechanics: Theory and Experiment, 2024(10), 104014. 10.1088/1742-5468/ad363c
- Albergo, M. S., & Vanden-Eijnden, E. (2022). Building Normalizing Flows with Stochastic Interpolants. The Eleventh International Conference on Learning Representations.
- Lipman, Y., Chen, R. T., Ben-Hamu, H., Nickel, M., & Le, M. (2022). Flow Matching for Generative Modeling. The Eleventh International Conference on Learning Representations.