PyTorch implementation of the NeurIPS 2022 paper Poisson Flow Generative Models,
by Yilun Xu*, Ziming Liu*, Max Tegmark, Tommi S. Jaakkola
We propose a new Poisson flow generative model (PFGM) that maps a uniform distribution on a high-dimensional hemisphere into any data distribution. We interpret the data points as electrical charges on the
Experimentally, PFGM achieves current state-of-the-art performance among normalizing flow models on CIFAR-10, with an Inception score of 9.68 and an FID score of 2.35. It also performs on par with state-of-the-art SDE approaches (e.g., score-based SDEs or diffusion models) while offering 10x to 20x acceleration on image generation tasks. Additionally, PFGM appears more tolerant of estimation errors on a weaker network architecture, robust to the step size in the Euler method, and capable of scaling up to higher-resolution datasets.
Acknowledgement: Our implementation relies on the repo https://github.com/yang-song/score_sde_pytorch.
We provide two options for installing a subset of the Python packages required by our code. Choose the one that best fits your environment.
- The original dependencies of the repo https://github.com/yang-song/score_sde_pytorch:
pip install -r requirements_old.txt
- Our dependencies (Python 3.9.12, CUDA 11.6):
pip install -r requirements.txt
Train and evaluate our models through main.py.
python3 main.py:
  --config: Training configuration.
  --eval_folder: The folder name for storing evaluation results (default: 'eval').
  --mode: <train|eval>: Running mode: train or eval.
  --workdir: Working directory.
For example, to train a new PFGM w/ DDPM++ model on the CIFAR-10 dataset, one could execute:
python3 main.py --config ./configs/poisson/cifar10_ddpmpp.py --mode train \
--workdir poisson_ddpmpp
- config is the path to the config file. The prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory. Naming conventions of config files: the path of a config file is a combination of the following dimensions:
  - Method: 🌟PFGM: poisson; score-based models: ve, vp, sub_vp.
  - Dataset: one of cifar10, celeba64, celebahq_256, ffhq_256, celebahq, ffhq.
  - Model: one of ncsnpp, ddpmpp.
  - continuous: train the model with continuously sampled time steps (only for score-based models).
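The hypothetical helper below composes a config path from the dimensions listed above. It only reproduces the pattern of the example path ./configs/poisson/cifar10_ddpmpp.py. The _continuous suffix is our assumption about how the continuous dimension is spelled. Not every combination necessarily has a file in configs/.

```python
from pathlib import Path

# Options taken from the naming conventions above.
METHODS = {"poisson", "ve", "vp", "sub_vp"}
DATASETS = {"cifar10", "celeba64", "celebahq_256", "ffhq_256", "celebahq", "ffhq"}
MODELS = {"ncsnpp", "ddpmpp"}


def config_path(method: str, dataset: str, model: str,
                continuous: bool = False, root: str = "configs") -> Path:
    """Hypothetical helper: build configs/<method>/<dataset>_<model>[_continuous].py."""
    if method not in METHODS:
        raise ValueError(f"unknown method: {method}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if model not in MODELS:
        raise ValueError(f"unknown model: {model}")
    if continuous and method == "poisson":
        # The continuous dimension only applies to score-based models.
        raise ValueError("continuous time steps only apply to score-based models")
    # Assumed spelling of the continuous dimension in file names.
    stem = f"{dataset}_{model}" + ("_continuous" if continuous else "")
    return Path(root) / method / f"{stem}.py"


print(config_path("poisson", "cifar10", "ddpmpp").as_posix())  # configs/poisson/cifar10_ddpmpp.py
```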
🌟Important Note 1: We use a large batch (e.g., currently training.batch_size=4096 for CIFAR-10, ~25 GB GPU memory usage) to calculate the Poisson field for each mini-batch of samples (e.g., training.small_batch_size=128 for CIFAR-10). To adjust the GPU memory cost, please modify the training.batch_size parameter in the config files.

🌟Important Note 2: If the rk45 solver exhibits instability for your dataset/neural network, please try the forward Euler method or the Improved Euler method by setting the config.sampling.ode_solver parameter to forward_euler or improved_euler.

Please set some key hyper-parameters for your specific dataset by running:
python3 hyper-parameters.py
  --data_norm: Average data norm of the dataset
  --data_dim: Data dimension
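The two inputs of hyper-parameters.py can be computed from the training data. The sketch below rests on a few assumptions:
- "data norm" is taken to mean the average L2 norm of a flattened sample.
- The images are assumed to be scaled the same way your data pipeline feeds them to the model. Check this, because the average norm depends on the value range.
- dataset_stats is a hypothetical helper, not part of the repo.
- The random array only stands in for CIFAR-10-shaped data.

```python
import numpy as np


def dataset_stats(images: np.ndarray) -> tuple[float, int]:
    """Hypothetical helper: (average L2 norm, dimension) of flattened samples.

    images: array of shape (num_samples, C, H, W) or (num_samples, D),
    already scaled like the model inputs.
    """
    flat = images.reshape(len(images), -1).astype(np.float64)
    data_dim = flat.shape[1]                       # e.g. 3 * 32 * 32 for CIFAR-10
    data_norm = float(np.linalg.norm(flat, axis=1).mean())
    return data_norm, data_dim


rng = np.random.default_rng(0)
fake_images = rng.uniform(-1.0, 1.0, size=(256, 3, 32, 32))  # stand-in data
data_norm, data_dim = dataset_stats(fake_images)
print(f"data_norm={data_norm:.2f}, data_dim={data_dim}")
```

The printed values would then be passed to the script as --data_norm and --data_dim.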
We also list a few other useful tips in the Tips section.
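This minimal NumPy sketch of an empirical Poisson field shows why the batch size in Important Note 1 drives memory use:
- The large batch acts as charges on the z = 0 hyperplane of the augmented space.
- The field is evaluated at perturbed mini-batch points.
- Each charge y_i contributes (x - y_i)/||x - y_i||^N, where N is the data dimension plus one.
- The helper returns only the field direction.

This is our simplified illustration, not the repository's training code. normalized_poisson_field and the toy perturbation in the usage are hypothetical, and the actual training target may be rescaled or regularized differently.

```python
import numpy as np


def normalized_poisson_field(queries: np.ndarray, charges: np.ndarray) -> np.ndarray:
    """Unit direction of the empirical Poisson field at each query point.

    queries: (m, N) perturbed points (x, z) that must not coincide with a charge.
    charges: (B, N) data points lifted to the z = 0 hyperplane.
    """
    n_aug = queries.shape[1]                              # N = data_dim + 1
    diff = queries[:, None, :] - charges[None, :, :]      # (m, B, N)
    log_dist = np.log(np.linalg.norm(diff, axis=-1))      # (m, B)
    # Weight 1 / ||x - y_i||^N per charge, computed in log space and shifted
    # by the row maximum so a large N does not underflow to zero.
    log_w = -n_aug * log_dist
    log_w -= log_w.max(axis=1, keepdims=True)
    w = np.exp(log_w)
    field = (w[..., None] * diff).sum(axis=1)             # positive rescaling of the field
    return field / np.linalg.norm(field, axis=1, keepdims=True)


rng = np.random.default_rng(0)
data_dim = 8
# Large batch of toy data lifted to z = 0 (plays the role of training.batch_size).
charges = np.concatenate([rng.normal(size=(512, data_dim)), np.zeros((512, 1))], axis=1)
# Hypothetical perturbation of a small batch (plays the role of training.small_batch_size).
noise = np.concatenate([0.1 * rng.normal(size=(32, data_dim)),
                        rng.uniform(0.1, 1.0, size=(32, 1))], axis=1)
directions = normalized_poisson_field(charges[:32] + noise, charges)
print(directions.shape)  # (32, 9)
```

In this sketch the intermediate diff array has shape (small batch, large batch, N). Its size therefore grows with training.small_batch_size times training.batch_size times the data dimension.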
- workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.
- eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.
- mode is either "train" or "eval". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in workdir/checkpoints-meta.
Below is the list of evaluation command-line flags:
  --config.eval.enable_sampling: Generate samples and evaluate sample quality, measured by FID and Inception score.
  --config.eval.enable_bpd: Compute log-likelihoods.
  --config.eval.dataset=train/test: Indicate whether to compute the likelihoods on the training or test dataset.
  --config.eval.enable_interpolate: Image interpolation.
  --config.eval.enable_rescale: Temperature scaling.
🌟 How to set the hyper-parameters: The prior distribution on the $z=z_{max}$ hyperplane is a long-tailed distribution, so we recommend clipping the sample norm with the hyper-parameter sampling.upper_norm. Please refer to Appendix B.1.1 and Appendix B.2.1 of the paper (https://arxiv.org/abs/2209.11178) for our recommended settings of the hyper-parameters training.M, sampling.z_max and sampling.upper_norm for general datasets. We provide a script for easily calculating those hyper-parameters:
python3 hyper-parameters.py
  --data_norm: Average data norm of the dataset
  --data_dim: Data dimension
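This minimal NumPy sketch draws starting points on the $z=z_{max}$ hyperplane and clips their norms with sampling.upper_norm. It assumes the prior is the radial projection of a uniform distribution on the upper hemisphere onto that hyperplane. That projection is a multivariate Cauchy distribution with scale z_max, and its heavy tail is what makes clipping useful. sample_prior and the numbers in the usage line are illustrative only. They are not the repository's sampler or recommended settings; see Appendix B of the paper for those.

```python
import numpy as np


def sample_prior(num: int, data_dim: int, z_max: float, upper_norm: float,
                 rng: np.random.Generator) -> np.ndarray:
    """Hypothetical sampler for the x-coordinates of points on z = z_max."""
    g = rng.standard_normal((num, data_dim + 1))
    # g / ||g|| is uniform on the sphere; using |g_z| keeps it on the upper
    # hemisphere. The ray through that direction meets z = z_max at
    # x = z_max * g_x / |g_z|, a multivariate Cauchy sample with scale z_max.
    x = z_max * g[:, :data_dim] / np.abs(g[:, data_dim:])
    # Clip long-tailed norms to upper_norm while keeping each direction.
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x * np.minimum(1.0, upper_norm / norms)


rng = np.random.default_rng(0)
x = sample_prior(1000, data_dim=16, z_max=10.0, upper_norm=50.0, rng=rng)
print(x.shape)  # (1000, 16)
```

In this sketch, clipping rescales a sample instead of rejecting it, so every draw is kept and only its norm is capped.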
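The sketch below illustrates the forward_euler and improved_euler choices for config.sampling.ode_solver. It implements the fixed-step forward Euler and improved Euler (Heun's method) update rules for a generic ODE dx/dt = f(t, x) and compares them on the toy problem dx/dt = -x.

It is not the repository's sampler. In PFGM, roughly speaking, the network-predicted field plays the role of f, and the augmented coordinate is integrated from z_max down towards 0. Both aspects are omitted here.

Each forward Euler step costs one evaluation of f and each improved Euler step costs two. RK45, in contrast, adapts its step size to a tolerance.

```python
import math
from typing import Callable

import numpy as np

Field = Callable[[float, np.ndarray], np.ndarray]


def forward_euler(f: Field, x0: np.ndarray, t0: float, t1: float, steps: int) -> np.ndarray:
    """Fixed-step forward Euler: x <- x + h * f(t, x)."""
    x, t = np.asarray(x0, dtype=float), t0
    h = (t1 - t0) / steps
    for _ in range(steps):
        x = x + h * f(t, x)
        t += h
    return x


def improved_euler(f: Field, x0: np.ndarray, t0: float, t1: float, steps: int) -> np.ndarray:
    """Improved Euler (Heun): predict with an Euler step, then average both slopes."""
    x, t = np.asarray(x0, dtype=float), t0
    h = (t1 - t0) / steps
    for _ in range(steps):
        k1 = f(t, x)
        k2 = f(t + h, x + h * k1)  # slope at the Euler-predicted end point
        x = x + 0.5 * h * (k1 + k2)
        t += h
    return x


def decay(t: float, x: np.ndarray) -> np.ndarray:
    """Toy ODE dx/dt = -x; with x(0) = 1 the exact solution is exp(-t)."""
    return -x


exact = math.exp(-1.0)
err_euler = abs(forward_euler(decay, np.array([1.0]), 0.0, 1.0, 20)[0] - exact)
err_heun = abs(improved_euler(decay, np.array([1.0]), 0.0, 1.0, 20)[0] - exact)
print(f"forward Euler error: {err_euler:.2e}, improved Euler error: {err_heun:.2e}")
```

With the same number of steps, improved Euler is more accurate here at twice the cost per step. This accuracy-for-evaluations trade-off is the one to weigh when switching away from rk45.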
TODO
Please place the pretrained checkpoints under the directory workdir/checkpoints, e.g., cifar10_ddpmpp/checkpoints.
To generate 10k samples from the PFGM w/ DDPM++ model and evaluate their FID/IS, you could execute:
python3 main.py --config ./configs/poisson/cifar10_ddpmpp.py --mode eval \
--workdir cifar10_ddpmpp --config.eval.enable_sampling --config.eval.num_samples 10000
To only generate and visualize 100 samples of the PFGM w/ DDPM++ model, you could execute:
python3 main.py --config ./configs/poisson/cifar10_ddpmpp.py --mode eval \
--workdir cifar10_ddpmpp --config.eval.enable_sampling --config.eval.save_images --config.eval.batch_size 100
The samples will be saved to cifar10_ddpmpp/eval/ode_images_{ckpt}.png.
All checkpoints are provided in this Google Drive folder.
| Dataset | Checkpoint path | Invertible? | IS | FID | NFE (RK45) |
|---|---|---|---|---|---|
| CIFAR-10 | poisson/cifar10_ddpmpp/ | ✔️ | 9.65 | 2.48 | ~104 |
| CIFAR-10 | poisson/cifar10_ddpmpp_deep/ | ✔️ | 9.68 | 2.35 | ~110 |
| LSUN bedroom | poisson/bedroom_ddpmpp/ | ✔️ | - | 13.66 | ~122 |
| CelebA | poisson/celeba_ddpmpp/ | ✔️ | - | 3.68 | ~110 |
Please find the statistics for FID scores in the following links: