
yuanchenyang/smalldiffusion


smalldiffusion


A lightweight diffusion library for training and sampling from diffusion models. It is built for easy experimentation when training new models and developing new samplers, supporting everything from minimal toy models to state-of-the-art pretrained models. The core of this library for diffusion training and sampling is implemented in fewer than 100 lines of very readable PyTorch code. To install from PyPI:

pip install smalldiffusion

Toy models

To train and sample from the Swissroll toy dataset in 10 lines of code (see examples/toyexample.ipynb for a detailed guide):

import numpy as np
from torch.utils.data import DataLoader
from smalldiffusion import Swissroll, TimeInputMLP, ScheduleLogLinear, training_loop, samples

dataset  = Swissroll(np.pi/2, 5*np.pi, 100)                       # 2-D toy data on a spiral
loader   = DataLoader(dataset, batch_size=2048)
model    = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
trainer  = training_loop(loader, model, schedule, epochs=15000)   # generator: training runs as it is consumed
losses   = [ns.loss.item() for ns in trainer]                      # run training, recording the loss
*xt, x0  = samples(model, schedule.sample_sigmas(20), gam=2)       # x0 is the final sample

Results on various toy datasets:

Diffusion Transformer

We provide a concise implementation of the diffusion transformer (DiT) introduced in [Peebles and Xie 2022]. To train a model on the FashionMNIST dataset and generate a batch of samples, first run accelerate config, then:

accelerate launch examples/fashion_mnist_dit.py

With the provided default parameters and around 2 hours of training on a single GPU, the model can achieve an FID score of around 5-6, producing the following generated outputs:

U-Net models

The same code can be used to train U-Net-based models.

accelerate launch examples/fashion_mnist_unet.py

StableDiffusion

smalldiffusion's sampler works with any pretrained diffusion model and supports DDPM, DDIM, and accelerated sampling algorithms. In examples/diffusers_wrapper.py, we provide a simple wrapper for any pretrained Hugging Face diffusers latent diffusion model, enabling sampling from pretrained models with only a few lines of code:

from diffusers_wrapper import ModelLatentDiffusion   # examples/diffusers_wrapper.py, so run from examples/
from smalldiffusion import ScheduleLDM, samples

schedule = ScheduleLDM(1000)
model    = ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')
model.set_text_condition('An astronaut riding a horse')
*xts, x0 = samples(model, schedule.sample_sigmas(50))   # 50 sampling steps in latent space
decoded  = model.decode_latents(x0)                     # decode the final latents into images
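Pretrained DDPM-style models are usually trained in a variance-preserving form, x_t = sqrt(alpha_bar_t) x0 + sqrt(1 - alpha_bar_t) eps, while the sampler here works with x = x0 + sigma eps. Dividing the first form by sqrt(alpha_bar_t) shows the two are equivalent with sigma = sqrt((1 - alpha_bar_t) / alpha_bar_t), so such a model can be queried at x / sqrt(1 + sigma^2) with the same noise target. The adapter below is a hypothetical sketch of that idea for an unconditional model with a scalar sigma; we have not checked that examples/diffusers_wrapper.py is implemented this way, and text conditioning and latent decoding are omitted.

```python
import torch

class VPToSigmaWrapper(torch.nn.Module):
    """Hypothetical adapter: exposes a variance-preserving noise predictor
    eps_vp(x_vp, t) through the (x, sigma) calling convention."""

    def __init__(self, eps_vp, alpha_bar):
        super().__init__()
        self.eps_vp = eps_vp
        # Noise level of each discrete timestep: sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t).
        self.register_buffer("sigmas", ((1 - alpha_bar) / alpha_bar).sqrt())

    def forward(self, x, sigma):
        sigma = torch.as_tensor(sigma, dtype=self.sigmas.dtype)
        # Nearest discrete timestep for the requested noise level (scalar sigma assumed).
        t = torch.argmin((self.sigmas - sigma).abs())
        # x = x0 + sigma * eps corresponds to x_vp = x / sqrt(1 + sigma^2); eps is unchanged.
        return self.eps_vp(x / (1 + sigma**2).sqrt(), t)

# Usage with an exact VP noise predictor for data that is always the single point mu.
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alpha_bar = torch.cumprod(1 - betas, dim=0)
mu = torch.tensor([0.5, -1.0], dtype=torch.float64)

def eps_vp(x_vp, t):
    a = alpha_bar[t]
    return (x_vp - a.sqrt() * mu) / (1 - a).sqrt()

wrapped = VPToSigmaWrapper(eps_vp, alpha_bar)
x = torch.randn(3, 2, dtype=torch.float64)
sigma = wrapped.sigmas[500]
out = wrapped(x, sigma)   # equals (x - mu) / sigma, the noise predictor in sigma form
```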

It is easy to experiment with different sampler parameters and sampling schedules, as demonstrated in examples/stablediffusion.py. A few examples of tweaking the parameter gam:

How to use

The core of smalldiffusion depends on the interaction between data, model and schedule objects. Here we give a specification of these objects. For a detailed introduction to diffusion models and the notation used in the code, see the accompanying tutorial.

Data

For training diffusion models, smalldiffusion supports pytorch Datasets and DataLoaders. The training code expects the iterates from a DataLoader object to be batches of data, without labels. To remove labels from existing datasets, extract the data with the provided MappedDataset wrapper before constructing a DataLoader.
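The label-stripping step can be sketched in plain PyTorch. This is a minimal sketch: DropLabels is a hypothetical class written for illustration, standing in for the provided MappedDataset wrapper, whose exact signature is not shown here. It assumes each item of the original dataset is a (data, label) pair, as with torchvision datasets.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

class DropLabels(Dataset):
    """Hypothetical wrapper that keeps only the data part of (data, label) items."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        data, _label = self.dataset[i]   # discard the label
        return data

# A labeled toy dataset: 100 two-dimensional points with integer labels.
points = torch.randn(100, 2)
labels = torch.randint(0, 10, (100,))
labeled = TensorDataset(points, labels)

loader = DataLoader(DropLabels(labeled), batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch.shape)   # torch.Size([32, 2]): data only, no labels
```

Each batch from the loader is then a single tensor of data, which is the form the training code expects.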

Two toy datasets, Swissroll and DatasaurusDozen, are provided.

Model

All model objects should be subclasses of torch.nn.Module. Models should have:

  • An attribute input_dims, a tuple containing the dimensions of the input to the model (not including the batch size).
  • A method rand_input(batchsize) which takes in a batch-size and returns an i.i.d. standard normal random input with shape [batchsize, *input_dims]. This method can be inherited from the provided ModelMixin class when the input_dims parameter is set.
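As a sketch of how rand_input can be derived from input_dims alone, the mixin below plays the role described for ModelMixin. RandInputMixin and TinyModel are hypothetical names for illustration, not the library's own classes.

```python
import torch
from torch import nn

class RandInputMixin:
    """Hypothetical stand-in for ModelMixin: derives rand_input from input_dims."""

    def rand_input(self, batchsize):
        # i.i.d. standard normal noise of shape [batchsize, *input_dims]
        return torch.randn((batchsize, *self.input_dims))

class TinyModel(nn.Module, RandInputMixin):
    def __init__(self, dim=2):
        super().__init__()
        self.input_dims = (dim,)   # per-example shape, batch dimension excluded
        self.net = nn.Linear(dim, dim)

x = TinyModel(dim=2).rand_input(5)
print(x.shape)   # torch.Size([5, 2])
```

An image model would instead set, for example, input_dims = (1, 28, 28), and rand_input would return tensors of shape [batchsize, 1, 28, 28].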

Models are called with two arguments:

  • x is a batch of data of batch-size B and shape [B, *model.input_dims].
  • sigma is either a singleton or a batch.
    1. If sigma.shape == [], the same value will be used for each x.
    2. Otherwise sigma.shape == [B, 1, ..., 1], and x[i] will be paired with sigma[i].

Models should return a predicted noise value with the same shape as x.
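The calling convention above can be sketched with a small MLP for 2-D data. ToyEpsModel is hypothetical, and conditioning the network by concatenating sigma to the input is just one simple choice for illustration. The forward method handles both sigma shapes: a scalar is shared by every example, and a [B, 1] tensor is paired element-wise with x.

```python
import torch
from torch import nn

class ToyEpsModel(nn.Module):
    """Hypothetical MLP that predicts noise from (x, sigma) for 2-D data."""

    def __init__(self, dim=2, hidden=64):
        super().__init__()
        self.input_dims = (dim,)
        self.net = nn.Sequential(nn.Linear(dim + 1, hidden), nn.ReLU(), nn.Linear(hidden, dim))

    def forward(self, x, sigma):
        B = x.shape[0]
        sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
        # Case 1: sigma.shape == []       -> the same value is used for every x[i].
        # Case 2: sigma.shape == [B, 1]   -> x[i] is paired with sigma[i].
        s = sigma.expand(B) if sigma.ndim == 0 else sigma.reshape(B)
        h = torch.cat([x, s.unsqueeze(1)], dim=1)   # shape [B, dim + 1]
        return self.net(h)                          # predicted noise, same shape as x

model = ToyEpsModel()
x = torch.randn(8, 2)
out_shared = model(x, torch.tensor(0.5))            # one sigma for the whole batch
out_batch = model(x, torch.full((8, 1), 0.5))       # one sigma per example
```

Passing a scalar sigma or a batch filled with the same value gives the same prediction, which is exactly what the two cases of the specification require.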

Schedule

A Schedule object determines the rate at which the noise level sigma increases during the diffusion process. It is constructed by simply passing in a tensor of increasing sigma values. Schedule objects have the following methods:

  • sample_sigmas(steps), which subsamples the schedule for sampling.
  • sample_batch(batchsize), which generates a batch of sigma values selected uniformly at random, for use in training.
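A minimal sketch of this interface, assuming sample_sigmas picks evenly spaced entries of the schedule (the library's actual spacing may differ) and sample_batch draws entries uniformly with replacement. SketchSchedule is a hypothetical name, not a class provided by smalldiffusion.

```python
import torch

class SketchSchedule:
    """Hypothetical minimal schedule built from an increasing tensor of sigmas."""

    def __init__(self, sigmas):
        self.sigmas = torch.as_tensor(sigmas, dtype=torch.float32)
        assert torch.all(self.sigmas[1:] > self.sigmas[:-1]), "sigmas must increase"

    def sample_sigmas(self, steps):
        # steps + 1 roughly evenly spaced indices, from the largest sigma down to the
        # smallest, giving the decreasing list a sampler expects.
        idx = torch.linspace(len(self.sigmas) - 1, 0, steps + 1).round().long()
        return self.sigmas[idx]

    def sample_batch(self, batchsize):
        # Schedule entries drawn uniformly at random (with replacement) for training.
        idx = torch.randint(len(self.sigmas), (batchsize,))
        return self.sigmas[idx]

schedule = SketchSchedule(torch.logspace(-2, 1, 200))
sampling_sigmas = schedule.sample_sigmas(20)
training_sigmas = schedule.sample_batch(16)
```

In this sketch, sample_sigmas(steps) returns steps + 1 noise levels, so a sampler stepping between consecutive pairs takes exactly steps steps.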

Three schedules are provided:

  1. ScheduleLogLinear is a simple schedule which works well on small datasets and toy models.
  2. ScheduleDDPM is commonly used in pixel-space image diffusion models.
  3. ScheduleLDM is commonly used in latent diffusion models, e.g. StableDiffusion.
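The three schedules differ in how the increasing sigma values are laid out. The sketch below computes a log-linear (geometric) schedule with the parameters used in the toy example, and converts DDPM-style beta schedules into sigma values via alpha_bar_t = prod(1 - beta_s) and sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t). The beta ranges are the standard linear schedule from the DDPM paper and the "scaled linear" schedule commonly used for Stable Diffusion; we have not checked that they match the defaults of ScheduleDDPM and ScheduleLDM, and the function names are hypothetical.

```python
import math
import torch

def log_linear_sigmas(N=200, sigma_min=0.005, sigma_max=10.0):
    # N values evenly spaced in log(sigma), i.e. a geometric sequence.
    return torch.logspace(math.log10(sigma_min), math.log10(sigma_max), N)

def sigmas_from_betas(betas):
    # Variance-preserving noise levels converted to equivalent sigma values.
    alpha_bar = torch.cumprod(1 - betas, dim=0)
    return torch.sqrt((1 - alpha_bar) / alpha_bar)

loglin = log_linear_sigmas()
# Linear betas from 1e-4 to 0.02 over 1000 steps (DDPM).
ddpm = sigmas_from_betas(torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64))
# "Scaled linear" betas: linear in sqrt(beta), then squared.
ldm = sigmas_from_betas(torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float64) ** 2)
```

The log-linear schedule has a constant ratio between consecutive sigmas, while the beta-derived schedules are dense at small noise levels and reach a much larger sigma_max in the DDPM case than in the latent diffusion case.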

The following plot shows these three schedules with default parameters.

Training

The training_loop generator function provides a simple loop for training a diffusion model, given the loader, model and schedule objects described above. It yields a namespace containing the loop's local variables, for easy evaluation during training. For example, to print the loss at every iteration:

for ns in training_loop(loader, model, schedule):
    print(ns.loss.item())

Multi-GPU training and sampling is also supported via accelerate.
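A loop of this kind can be sketched as follows, assuming the standard noise-prediction objective implied by the model specification above: add sigma-scaled Gaussian noise to a clean batch and regress the model output onto that noise with a mean-squared error. toy_training_loop, TinyEps and sample_sigma_batch are hypothetical names, and this is not the library's implementation. Like training_loop, it is a generator that yields a namespace of its local variables.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from types import SimpleNamespace

class TinyEps(nn.Module):
    """Hypothetical noise-prediction model for 2-D data."""
    def __init__(self, dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 64), nn.ReLU(), nn.Linear(64, dim))

    def forward(self, x, sigma):
        s = sigma.expand(x.shape[0], 1)   # works for a scalar or a [B, 1] sigma
        return self.net(torch.cat([x, s], dim=1))

def toy_training_loop(loader, model, sample_sigma_batch, epochs=1, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x0 in loader:
            B = x0.shape[0]
            # One sigma per example, reshaped to [B, 1, ..., 1] to broadcast against x0.
            sigma = sample_sigma_batch(B).reshape(B, *([1] * (x0.ndim - 1)))
            eps = torch.randn_like(x0)
            loss = nn.functional.mse_loss(model(x0 + sigma * eps, sigma), eps)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            yield SimpleNamespace(**locals())   # exposes e.g. ns.loss, ns.sigma

torch.manual_seed(0)
data = torch.randn(512, 2) * 0.5 + 1.0
loader = DataLoader(data, batch_size=128, shuffle=True)
sigmas = torch.logspace(-2, 1, 100)
sample_sigma_batch = lambda B: sigmas[torch.randint(len(sigmas), (B,))]
losses = [ns.loss.item() for ns in toy_training_loop(loader, TinyEps(), sample_sigma_batch, epochs=5)]
```

Yielding the whole local namespace keeps the loop itself free of logging code: the caller decides what to inspect, as in the print example above.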

Sampling

To sample from a diffusion model, the samples generator function takes in a model and a decreasing list of sigmas to use during sampling. This list is usually created by calling the sample_sigmas(steps) method of a Schedule object. The generator yields the sequence of intermediate samples xt produced during sampling. The sampling loop generalizes most commonly used samplers.

For more details on how these sampling algorithms can be simplified, generalized and implemented in only 5 lines of code, see Appendix A of [Permenter and Yuan].
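A deterministic sampler in this style can be sketched in a few lines. With gam=1, each step moves from a noise level sig to the next, smaller level sig_next with the DDIM update x_next = x - (sig - sig_next) * eps. For other values of gam, the sketch blends the current and previous noise predictions; this is our assumption about how gam enters, for illustration only, so see [Permenter and Yuan] for the sampler the library actually implements. toy_samples and its batchsize/dim arguments are hypothetical, and stochastic (DDPM-style) sampling is omitted.

```python
import torch

def toy_samples(model, sigmas, gam=1.0, x=None, batchsize=1, dim=2):
    """Hypothetical deterministic sampler sketch; sigmas must be decreasing."""
    if x is None:
        x = torch.randn(batchsize, dim) * sigmas[0]   # pure noise at the largest sigma
    eps = None
    for i in range(len(sigmas) - 1):
        sig, sig_next = sigmas[i], sigmas[i + 1]
        eps_prev, eps = eps, model(x, sig)
        # gam=1: current prediction only (DDIM step); otherwise blend with the previous one.
        eps_av = eps if eps_prev is None else gam * eps + (1 - gam) * eps_prev
        x = x - (sig - sig_next) * eps_av
        yield x

torch.manual_seed(0)
mu = torch.tensor([1.0, -2.0])

def oracle(x, sigma):
    # Exact noise predictor when every training example equals mu.
    return (x - mu) / sigma

sigmas = torch.cat([torch.logspace(1, -2, 20), torch.zeros(1)])
*xts, x0 = toy_samples(oracle, sigmas, gam=2.0, batchsize=4)
```

Because the oracle is the exact noise predictor for a dataset consisting of the single point mu, every trajectory ends at mu when the last noise level is 0.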

About

Simple and readable code for training and sampling from diffusion models
