add one-step sdxl training instructions
tianweiy committed May 28, 2024
1 parent 01f37a8 commit 798a1ec
Showing 6 changed files with 120 additions and 11 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -57,13 +57,13 @@ python setup.py develop

## Inference Example

-To reproduce our ImageNet results, run:
+#### ImageNet

```.bash
python -m demo.imagenet_example --checkpoint_path IMAGENET_CKPT_PATH
```

-To try our text-to-image generation demo, run:
+#### Text-to-Image

```.bash
# Note: on the demo page, click "Use Tiny VAE for faster decoding" for much faster decoding and lower memory consumption, using the Tiny VAE from [madebyollin](https://huggingface.co/madebyollin/taesdxl).
@@ -77,7 +77,7 @@ python -m demo.text_to_image_sdxl --num_step 1 --checkpoint_path SDXL_CKPT_PATH

You can also use the standard diffusers pipeline:

-4-step generation
+#### 4-step generation

```python
import torch
@@ -86,7 +86,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "tianweiy/DMD2"
-ckpt_name = "dmd2_sdxl_4step_unet.bin"
+ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
# Load model.
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
@@ -96,7 +96,7 @@ prompt="a photo of a cat"
image=pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0).images[0]
```

-1-step generation
+#### 1-step generation

```python
import torch
@@ -105,7 +105,7 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "tianweiy/DMD2"
-ckpt_name = "dmd2_sdxl_1step_unet.bin"
+ckpt_name = "dmd2_sdxl_1step_unet_fp16.bin"
# Load model.
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
2 changes: 1 addition & 1 deletion experiments/imagenet/README.md
@@ -8,7 +8,7 @@ We trained ImageNet using mixed-precision in BF16 format, adapting the EDM's cod
| ----------- | --- | ---- | ----- | ----- |
| [imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch](./imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch.sh) | 1.51 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch_fid1.51_checkpoint_model_193500/) | 200k | 53 |
| [imagenet_lr2e-6_scratch](./imagenet_lr2e-6_scratch.sh) | 2.61 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/imagenet/imagenet_lr2e-6_scratch_fid2.61_checkpoint_model_405500/) | 410k | 70 |
-| [imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume*](./imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume.sh) | 1.28 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume_fid1.28_checkpoint_model_548000/) | 140K | 70 |
+| [imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume*](./imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume.sh) | 1.28 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume_fid1.28_checkpoint_model_548000/) | 140K | 38 |

*The final model was resumed from the best checkpoint of the **imagenet_lr2e-6_scratch** run and trained for an additional 140,000 iterations.

46 changes: 42 additions & 4 deletions experiments/sdxl/README.md
@@ -5,9 +5,7 @@
| Config Name | FID | Link | Iters | Hours |
| ----------- | --- | ---- | ----- | ----- |
| [sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch](./sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch.sh) | 19.32 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdxl/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_checkpoint_model_019000) | 19k | 57 |
-| [sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode](./laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_resume.sh) | 19.01 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode_checkpoint_model_024000) | TBD | TBD |
-
-1-step model training needs some special handling, we will support it soon.
+| [sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode](./) | 19.01 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode_checkpoint_model_024000) | 24k | 57 |

For inference with our models, you only need to download the pytorch_model.bin file from the provided link. For fine-tuning, you will need to download the entire folder.
You can use the following script for that:
@@ -40,7 +38,7 @@ bash scripts/download_sdxl.sh $CHECKPOINT_PATH

You can also add these exports to your `.bashrc` file so that you don't need to run them every time you open a new terminal.

-### Sample Training/Testing Commands
+### 4-step Sample Training/Testing Commands

```bash
# Start training on 64 GPUs. Run this script on all 8 nodes, changing EXP_NAME and NODE_RANK_ID accordingly.
@@ -58,4 +56,44 @@ python main/sdxl/test_folder_sdxl.py \
--wandb_name test_sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch
```

### 1-step Sample Training/Testing Commands [Work In Progress]

The 1-step model needs an extra regression-loss pretraining stage.

First, download the 10K noise-image pairs:

```bash
bash scripts
```

Second, pretrain the model with the regression loss:

```bash
bash experiments/sdxl/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399.sh $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT $MASTER_IP
```
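
As a rough sanity check on the scale of this stage: the pretraining script added in this commit launches `torchrun` with `--nnodes 8 --nproc_per_node=8` and passes `--ode_batch_size 4` and `--num_ode_pairs 10000`, and the released checkpoint is named `..._checkpoint_model_002000`. The sketch below works through the arithmetic under three assumptions that the source does not confirm: `--ode_batch_size` is a per-process batch size, there is no gradient accumulation, and the checkpoint suffix is the iteration count. The function names are hypothetical.

```python
def global_batch_size(nodes: int, procs_per_node: int, per_process_batch: int) -> int:
    """Samples consumed per optimizer step.

    Assumes one process per GPU, a per-process batch size, and no
    gradient accumulation (none of which the source states explicitly).
    """
    return nodes * procs_per_node * per_process_batch


def passes_over_dataset(iterations: int, batch: int, dataset_size: int) -> float:
    """How many times the dataset has been seen after `iterations` steps."""
    return iterations * batch / dataset_size


if __name__ == "__main__":
    # Values taken from the torchrun command line in the pretraining script.
    batch = global_batch_size(nodes=8, procs_per_node=8, per_process_batch=4)
    print("global batch size:", batch)
    # 2000 is the numeric suffix of the released checkpoint (assumed to be iterations).
    print("passes over 10K pairs:", passes_over_dataset(2000, batch, 10_000))
```

Under those assumptions, each step consumes 256 pairs, so 2,000 iterations correspond to roughly 51 passes over the 10K pairs.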

Alternatively, you can skip the previous two steps and directly download the regression-loss pretrained checkpoint:

```bash
bash scripts/download_sdxl_1step_ode_pairs.sh $CHECKPOINT_PATH
```
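
For environments without `wget`, the same download could be sketched in Python. This is only an illustration: the repository ID and file path are copied from the `wget` URL in `scripts/download_sdxl_1step_ode_pairs.sh`, while `hf_resolve_url` and `download_pretrained_checkpoint` are hypothetical helper names, not part of the repository.

```python
import urllib.request
from pathlib import Path

REPO_ID = "tianweiy/DMD2"
CKPT_NAME = "sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin"


def hf_resolve_url(repo_id: str, path_in_repo: str, revision: str = "main") -> str:
    """Build a direct-download URL in the same shape as the one the script uses."""
    return f"https://huggingface.co/{repo_id}/resolve/{revision}/{path_in_repo}"


def download_pretrained_checkpoint(checkpoint_dir: str) -> Path:
    """Download the regression-loss pretrained checkpoint into checkpoint_dir."""
    target = Path(checkpoint_dir) / CKPT_NAME
    target.parent.mkdir(parents=True, exist_ok=True)
    url = hf_resolve_url(REPO_ID, f"model/sdxl/{CKPT_NAME}")
    urllib.request.urlretrieve(url, str(target))  # large file; needs network access
    return target
```

Calling `download_pretrained_checkpoint(os.environ["CHECKPOINT_PATH"])` would place the file where the training script's `--generator_ckpt_path` argument expects it.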

Then start the main training run:

```bash
# Start training on 64 GPUs. Run this script on all 8 nodes, changing EXP_NAME and NODE_RANK_ID accordingly.
bash experiments/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT fsdp_configs/EXP_NAME NODE_RANK_ID

# On another machine, start a testing process that continually reads from the checkpoint folder and evaluates the FID.
# Replace TIMESTAMP_TBD with the actual timestamp folder name.
python main/sdxl/test_folder_sdxl.py \
--folder $CHECKPOINT_PATH/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode/TIMESTAMP_TBD/ \
--conditioning_timestep 399 --num_step 1 --wandb_entity $WANDB_ENTITY \
--wandb_project $WANDB_PROJECT --num_train_timesteps 1000 \
--seed 10 --eval_res 512 --ref_dir $CHECKPOINT_PATH/coco10k/subset \
--anno_path $CHECKPOINT_PATH/coco10k/all_prompts.pkl \
--total_eval_samples 10000 --clip_score \
--wandb_name test_sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode
```
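
The `TIMESTAMP_TBD` placeholder in the testing command has to be filled in by hand. One way to look it up is sketched below, assuming each training run writes into its own subfolder of the experiment's output path and that the most recently modified subfolder is the run you want. Both points are assumptions, and `latest_run_dir` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def latest_run_dir(experiment_dir: str) -> Path:
    """Return the most recently modified subdirectory of experiment_dir.

    Assumes every run creates one subfolder (the TIMESTAMP_TBD part of the
    path) directly under the experiment's output path.
    """
    subdirs = [p for p in Path(experiment_dir).iterdir() if p.is_dir()]
    if not subdirs:
        raise FileNotFoundError(f"no run folders found in {experiment_dir}")
    return max(subdirs, key=lambda p: p.stat().st_mtime)
```

The returned path could then be passed as the `--folder` argument of `test_folder_sdxl.py`.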

Please refer to [train_sd.py](../../main/train_sd.py) for various training options. Notably, if the `--delete_ckpts` flag is set to `True`, all checkpoints except the latest one will be deleted during training. Additionally, you can use the `--cache_dir` flag to specify a location with larger storage capacity. The number of checkpoints stored in `cache_dir` is controlled by the `max_checkpoint` argument.
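
To make the retention behavior described above concrete, here is a minimal sketch of a "keep only the newest N checkpoints" policy. It is an illustration only, not the logic in `train_sd.py`: the function name `prune_checkpoints` is hypothetical, and the folder naming (`checkpoint_model_<iteration>`) is assumed from the suffixes that appear in the checkpoint links.

```python
import re
import shutil
from pathlib import Path

# Assumed naming: folders ending in checkpoint_model_<iteration>.
_CKPT_PATTERN = re.compile(r"checkpoint_model_(\d+)$")


def prune_checkpoints(cache_dir: str, max_checkpoint: int) -> list[str]:
    """Delete all but the newest `max_checkpoint` checkpoint folders.

    Returns the names of the folders that were removed, oldest first.
    """
    if max_checkpoint < 1:
        raise ValueError("max_checkpoint must be at least 1")
    found = []
    for path in Path(cache_dir).iterdir():
        match = _CKPT_PATTERN.search(path.name)
        if path.is_dir() and match:
            found.append((int(match.group(1)), path))
    found.sort(key=lambda item: item[0])  # oldest iteration first
    stale = found[:-max_checkpoint]       # everything except the newest N
    for _, path in stale:
        shutil.rmtree(path)
    return [path.name for _, path in stale]
```

With `max_checkpoint=1`, this sketch behaves like the `--delete_ckpts` description: only the latest checkpoint survives.
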
@@ -0,0 +1,41 @@
export CHECKPOINT_PATH=$1
export WANDB_ENTITY=$2
export WANDB_PROJECT=$3
export FSDP_DIR=$4
export RANK=$5

# accelerate launch --config_file fsdp_configs/fsdp_1node_debug.yaml main/train_sd.py \
accelerate launch --config_file $FSDP_DIR/config_rank$RANK.yaml main/train_sd.py \
--generator_lr 5e-7 \
--guidance_lr 5e-7 \
--train_iters 100000000 \
--output_path $CHECKPOINT_PATH/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode \
--batch_size 2 \
--grid_size 2 \
--initialie_generator --log_iters 1000 \
--resolution 1024 \
--latent_resolution 128 \
--seed 10 \
--real_guidance_scale 8 \
--fake_guidance_scale 1.0 \
--max_grad_norm 10.0 \
--model_id "stabilityai/stable-diffusion-xl-base-1.0" \
--wandb_iters 100 \
--wandb_entity $WANDB_ENTITY \
--wandb_name "sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode" \
--log_loss \
--dfake_gen_update_ratio 5 \
--fsdp \
--sdxl \
--use_fp16 \
--max_step_percent 0.98 \
--cls_on_clean_image \
--gen_cls_loss \
--gen_cls_loss_weight 5e-3 \
--guidance_cls_loss_weight 1e-2 \
--diffusion_gan \
--diffusion_gan_max_timestep 1000 \
--conditioning_timestep 399 \
--train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \
--real_image_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb/ \
--generator_ckpt_path $CHECKPOINT_PATH/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin
27 changes: 27 additions & 0 deletions experiments/sdxl/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399.sh
@@ -0,0 +1,27 @@
export CHECKPOINT_PATH=$1
export WANDB_ENTITY=$2
export WANDB_PROJECT=$3
export MASTER_IP=$4

torchrun --nnodes 8 --nproc_per_node=8 --rdzv_id=2345 \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_IP main/train_sd_ode.py \
--generator_lr 1e-5 \
--train_iters 100000000 \
--output_path $CHECKPOINT_PATH/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399 \
--grid_size 1 \
--log_iters 1000 \
--resolution 1024 \
--seed 10 \
--max_grad_norm 10.0 \
--model_id "stabilityai/stable-diffusion-xl-base-1.0" \
--wandb_iters 250 \
--wandb_entity $WANDB_ENTITY \
--wandb_name "sdxl_lr1e-5_8node_ode_pretraining_10k_cond399" \
--sdxl \
--num_ode_pairs 10000 \
--ode_pair_path $CHECKPOINT_PATH/laion6.25_pair_generation_sdxl_guidance6_full_lmdb/ \
--ode_batch_size 4 \
--conditioning_timestep 399 \
--tiny_vae \
--use_fp16
3 changes: 3 additions & 0 deletions scripts/download_sdxl_1step_ode_pairs.sh
@@ -0,0 +1,3 @@
export CHECKPOINT_PATH=$1

wget "https://huggingface.co/tianweiy/DMD2/resolve/main/model/sdxl/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin?download=true" -O "$CHECKPOINT_PATH/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin"
