Data Generation

The generation code is modified from the official implementation of EDM and the official implementation of Carmon et al., 2019. We use the class-conditional EDM model in this implementation.

Requirements

  • This project has been tested on Ubuntu 20.04.3.
  • 4 NVIDIA A100 SXM4 40GB GPUs for training and image generation.
  • 64-bit Python 3.8 and PyTorch 1.12.0 (or later). See https://pytorch.org for PyTorch install instructions.
  • Python libraries: See environment.yml for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
    • conda env create -f environment.yml -n edm
    • conda activate edm
  • For 1M data generation, following Rebuffi et al., 2021, we use the official implementation of Carmon et al., 2019 to train WRN-28-10 models that provide pseudo-labels. Download the selection models into ./selection_model:
dataset     clean accuracy   checkpoint
CIFAR-10    96.15%           cifar10_pseudo.pt
CIFAR-100   80.47%           cifar100_pseudo.pt
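Before creating the environment, it can help to confirm that the interpreter and PyTorch build meet the minimums listed above (64-bit Python 3.8, PyTorch 1.12.0 or later). The following is a minimal sketch, not a script shipped with this repository; the helper names parse_version, meets_minimum and check_environment are hypothetical.

```python
import re
import sys


def parse_version(version: str) -> tuple:
    """Extract the leading numeric components of a version string.

    "1.12.0+cu113" -> (1, 12, 0); suffixes such as "+cu113" or ".dev..." are ignored.
    """
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(version: str, minimum: tuple) -> bool:
    """True if `version` is at least `minimum`, compared component by component."""
    return parse_version(version) >= minimum


def check_environment() -> list:
    """Return a list of problems with the current environment (empty if none found)."""
    problems = []
    if sys.maxsize <= 2**32:
        problems.append("a 64-bit Python is required")
    if sys.version_info < (3, 8):
        problems.append(f"Python >= 3.8 required, found {sys.version.split()[0]}")
    try:
        import torch  # imported lazily so the helpers above work without PyTorch
    except ImportError:
        problems.append("PyTorch is not installed")
    else:
        if not meets_minimum(torch.__version__, (1, 12, 0)):
            problems.append(f"PyTorch >= 1.12.0 required, found {torch.__version__}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "environment looks OK")
```

Running it inside the activated edm environment should print "environment looks OK" if the requirements are satisfied.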

Generating data for CIFAR-10

For CIFAR-10, we generate images using the pre-trained model provided by EDM, which yields a new state-of-the-art FID of 1.79.

For 1M data generation, following Rebuffi et al., 2021, we first generate 500K images per class (5M in total). Generating this many images is time-consuming, so we distribute the workload across multiple GPUs by launching generate.py with torchrun:

# Generate 500K images for each class using 4 A100 GPUs, using deterministic sampling with 20 steps
torchrun --standalone --nproc_per_node=4 generate.py --outdir=out_cifar10 --seeds=0-499999 --batch=2048  --steps=20 --class_num=10 \
    --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
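With torchrun, each of the four processes should generate a disjoint share of the --seeds range. The sketch below illustrates one way to do this, assuming seeds are grouped into batches of at most --batch and dealt out round-robin to ranks (similar in spirit to upstream EDM's generate.py). parse_seed_spec and shard_seeds are hypothetical names, not functions from this repository.

```python
def parse_seed_spec(spec: str) -> list:
    """Parse a seed spec such as "0-499999" or "0,5,10-19" into a list of ints.

    Ranges are inclusive on both ends, so "0-499999" denotes 500K seeds.
    """
    seeds = []
    for part in spec.split(","):
        if "-" in part:
            start, end = part.split("-")
            seeds.extend(range(int(start), int(end) + 1))
        else:
            seeds.append(int(part))
    return seeds


def shard_seeds(seeds: list, batch_size: int, rank: int, world_size: int) -> list:
    """Return the batches of seeds that process `rank` should handle.

    Hypothetical illustration: seeds are split into near-equal contiguous
    batches, and batches are assigned round-robin so the ranks never overlap.
    """
    if not seeds:
        return []
    # Round the batch count up to a multiple of world_size so every rank
    # runs the same number of batches, each holding at most batch_size seeds.
    num_batches = ((len(seeds) - 1) // (batch_size * world_size) + 1) * world_size
    # The first `extra` batches get one more seed than the rest.
    base, extra = divmod(len(seeds), num_batches)
    batches, start = [], 0
    for i in range(num_batches):
        size = base + (1 if i < extra else 0)
        batches.append(seeds[start:start + size])
        start += size
    return batches[rank::world_size]


if __name__ == "__main__":
    seeds = parse_seed_spec("0-499999")
    for rank in range(4):
        shard = shard_seeds(seeds, batch_size=2048, rank=rank, world_size=4)
        print(rank, len(shard), sum(len(b) for b in shard))
```

Because the shards are disjoint and together cover every seed, the four processes never duplicate work, and each seed maps to a fixed noise sample regardless of how many GPUs are used.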

Each .npy file is named after the label of the images it contains (e.g., 1.npy holds images of class 1). We use the pre-trained WRN-28-10 model to score each image and keep the top 100K images per class (1M in total):

python select_1M.py --model_path ./selection_model/cifar10_pseudo.pt --data_dir ./out_cifar10 --output_dir ./npz_cifar10 --class_num 10
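The selection step scores each image of a class with the WRN-28-10 model and keeps the highest-scoring ones. A minimal NumPy sketch of that top-k step, assuming the score is the softmax probability the selection model assigns to the class the image was generated for (the exact criterion in select_1M.py may differ; softmax and select_top_k are hypothetical helpers):

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def select_top_k(probs: np.ndarray, label: int, k: int) -> np.ndarray:
    """Indices of the k images with the highest predicted probability for `label`.

    probs: shape (num_images, num_classes), the selection model's softmax
    outputs for images that were generated with class `label`.
    """
    scores = probs[:, label]
    k = min(k, len(scores))
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    # argpartition isolates the k largest scores in linear time;
    # only those k are then sorted, from most to least confident.
    top = np.argpartition(-scores, k - 1)[:k]
    return top[np.argsort(-scores[top])]


if __name__ == "__main__":
    # Toy example: random logits for 1,000 images of class 1 over 10 classes.
    rng = np.random.default_rng(0)
    probs = softmax(rng.normal(size=(1000, 10)))
    keep = select_top_k(probs, label=1, k=100)
    print(keep.shape)  # (100,)
```

Applied to the 500K images in 1.npy with k=100_000, the returned indices would pick the kept images for class 1.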

When more than 1M generated images are required, we merge the .npy files without selection. For example, to produce 5M images:

python combine_data.py --data_dir ./out_cifar10 --output_dir ./npz_cifar10 --class_num 10 --file_name 5m
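Merging without selection amounts to concatenating the per-class arrays and attaching to each image the label given by its file name. A minimal sketch, assuming each <label>.npy holds an array whose first axis indexes images; the function name and the output keys image and label are illustrative and may not match what combine_data.py writes:

```python
import os

import numpy as np


def combine_class_files(data_dir: str, output_path: str, class_num: int) -> None:
    """Concatenate <data_dir>/0.npy ... <class_num - 1>.npy into one .npz file.

    Hypothetical sketch: the label of every image is taken from its file name.
    """
    images, labels = [], []
    for label in range(class_num):
        arr = np.load(os.path.join(data_dir, f"{label}.npy"))
        images.append(arr)
        labels.append(np.full(len(arr), label, dtype=np.int64))
    np.savez(output_path, image=np.concatenate(images), label=np.concatenate(labels))
```

Note that this naive version holds everything in memory: for 5M CIFAR-sized uint8 images the concatenated array alone is about 15.4 GB (5,000,000 × 32 × 32 × 3 bytes), so it needs a machine with ample RAM.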

Generating data for CIFAR-100

For CIFAR-100, we train our own model on four A100 GPUs and select the checkpoint with the best FID (2.09) at 25 sampling steps. For 1M data generation, we first generate 50K images per class (5M in total):

# Generate 50K images for each class using 4 A100 GPUs, using deterministic sampling with 25 steps
torchrun --standalone --nproc_per_node=4 generate.py --outdir=out_cifar100 --seeds=0-49999 --batch=2048  --steps=25 --class_num=100 \
    --network=https://huggingface.co/wzekai99/DM-Improves-AT/resolve/main/others/edm-cifar100-32x32-cond-vp.pkl
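The generation budgets of the two CIFAR settings follow from --seeds and --class_num. Assuming, as the commands and descriptions here indicate, that each seed yields one image per class, both settings generate 5M images and keep one fifth of each class. A quick arithmetic check (Plan and count_seeds are hypothetical names):

```python
from dataclasses import dataclass


def count_seeds(spec: str) -> int:
    """Number of seeds in an inclusive range spec such as "0-49999"."""
    start, end = (int(x) for x in spec.split("-"))
    return end - start + 1


@dataclass(frozen=True)
class Plan:
    name: str
    class_num: int       # value of --class_num
    seeds: str           # value of --seeds (assumed: one image per seed per class)
    kept_per_class: int  # images kept per class after selection

    @property
    def generated_total(self) -> int:
        return self.class_num * count_seeds(self.seeds)

    @property
    def kept_total(self) -> int:
        return self.class_num * self.kept_per_class


PLANS = [
    Plan("CIFAR-10", class_num=10, seeds="0-499999", kept_per_class=100_000),
    Plan("CIFAR-100", class_num=100, seeds="0-49999", kept_per_class=10_000),
]

for plan in PLANS:
    ratio = plan.kept_total / plan.generated_total
    print(f"{plan.name}: generate {plan.generated_total:,}, keep {plan.kept_total:,} ({ratio:.0%})")
# CIFAR-10: generate 5,000,000, keep 1,000,000 (20%)
# CIFAR-100: generate 5,000,000, keep 1,000,000 (20%)
```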

We use the pre-trained WRN-28-10 model to score each image and keep the top 10K images per class (1M in total):

python select_1M.py --model_path ./selection_model/cifar100_pseudo.pt --data_dir ./out_cifar100 --output_dir ./npz_cifar100 --class_num 100
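Because selection keeps a fixed number of images per class, a quick sanity check on the output is that every one of the 100 classes appears exactly 10K times. A sketch, assuming the labels of the selected images can be loaded as an integer array (how select_1M.py stores them is not specified here; class_counts and is_balanced are hypothetical helpers):

```python
import numpy as np


def class_counts(labels, class_num: int) -> np.ndarray:
    """Count how many images carry each label in [0, class_num)."""
    labels = np.asarray(labels, dtype=np.int64)
    if labels.size == 0:
        return np.zeros(class_num, dtype=np.int64)
    if labels.min() < 0 or labels.max() >= class_num:
        raise ValueError(f"labels must lie in [0, {class_num})")
    return np.bincount(labels, minlength=class_num)


def is_balanced(labels, class_num: int, per_class: int) -> bool:
    """True if every one of the class_num classes appears exactly per_class times."""
    return bool(np.all(class_counts(labels, class_num) == per_class))
```

For the CIFAR-100 setting, is_balanced(labels, class_num=100, per_class=10_000) should return True; a False result would point to a class for which too few images were generated.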

When more than 1M generated images are required, we merge the .npy files without selection. For example, to produce 5M images:

python combine_data.py --data_dir ./out_cifar100 --output_dir ./npz_cifar100 --class_num 100 --file_name 5m

Generating data for SVHN

For SVHN, we train our own model on four A100 GPUs and select the checkpoint with the best FID (1.39) at 25 sampling steps. Unlike CIFAR-10/CIFAR-100, the SVHN data is not filtered with a pre-trained selection model. For instance, for 1M data generation, we generate 100K images per class (1M in total):

# Generate 100K images for each class using 4 A100 GPUs, using deterministic sampling with 25 steps
torchrun --standalone --nproc_per_node=4 generate.py --outdir=out_svhn --seeds=0-99999 --batch=2048  --steps=25 --class_num=10 \
    --network=https://huggingface.co/wzekai99/DM-Improves-AT/resolve/main/others/edm-svhn-32x32-cond-vp.pkl

Then we merge the .npy files without selection:

python combine_data.py --data_dir ./out_svhn --output_dir ./npz_svhn --class_num 10 --file_name 1m

License

Source code and pre-trained models of EDM are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. The official implementation of Carmon et al., 2019 was originally released under the MIT license.