forked from tianweiy/DMD2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add one-step sdxl training instructions
- Loading branch information
Showing
6 changed files
with
120 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
experiments/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Launch one-step SDXL training (main/train_sd.py) through `accelerate`.
#
# Usage:
#   bash <this script> CHECKPOINT_PATH WANDB_ENTITY WANDB_PROJECT FSDP_DIR RANK
#
#   CHECKPOINT_PATH  directory that receives the run output and that holds the
#                    prompt file, the real-image latents and the generator
#                    checkpoint referenced at the end of the command
#   WANDB_ENTITY     forwarded to --wandb_entity
#   WANDB_PROJECT    exported to the environment only; it is not passed on the
#                    command line below
#   FSDP_DIR         directory containing config_rank<RANK>.yaml
#   RANK             selects which config_rank<RANK>.yaml this node uses
export CHECKPOINT_PATH="$1"
export WANDB_ENTITY="$2"
export WANDB_PROJECT="$3"
export FSDP_DIR="$4"
export RANK="$5"

# Fail early with a readable message instead of launching with paths such as
# "/captions_laion_score6.25.pkl" or "config_rank.yaml".
: "${CHECKPOINT_PATH:?CHECKPOINT_PATH (argument 1) is required}"
: "${WANDB_ENTITY:?WANDB_ENTITY (argument 2) is required}"
: "${FSDP_DIR:?FSDP_DIR (argument 4) is required}"
: "${RANK:?RANK (argument 5) is required}"

# Single source of truth for the output directory name and the wandb run name.
run_name="sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode"

# Single-node debugging alternative for the first line of the command:
# accelerate launch --config_file fsdp_configs/fsdp_1node_debug.yaml main/train_sd.py \
#
# NOTE(review): "--initialie_generator" is kept exactly as written; it
# presumably matches the option name defined by main/train_sd.py, so verify
# against that parser before changing the spelling.
accelerate launch --config_file "$FSDP_DIR/config_rank$RANK.yaml" main/train_sd.py \
    --generator_lr 5e-7 \
    --guidance_lr 5e-7 \
    --train_iters 100000000 \
    --output_path "$CHECKPOINT_PATH/$run_name" \
    --batch_size 2 \
    --grid_size 2 \
    --initialie_generator \
    --log_iters 1000 \
    --resolution 1024 \
    --latent_resolution 128 \
    --seed 10 \
    --real_guidance_scale 8 \
    --fake_guidance_scale 1.0 \
    --max_grad_norm 10.0 \
    --model_id "stabilityai/stable-diffusion-xl-base-1.0" \
    --wandb_iters 100 \
    --wandb_entity "$WANDB_ENTITY" \
    --wandb_name "$run_name" \
    --log_loss \
    --dfake_gen_update_ratio 5 \
    --fsdp \
    --sdxl \
    --use_fp16 \
    --max_step_percent 0.98 \
    --cls_on_clean_image \
    --gen_cls_loss \
    --gen_cls_loss_weight 5e-3 \
    --guidance_cls_loss_weight 1e-2 \
    --diffusion_gan \
    --diffusion_gan_max_timestep 1000 \
    --conditioning_timestep 399 \
    --train_prompt_path "$CHECKPOINT_PATH/captions_laion_score6.25.pkl" \
    --real_image_path "$CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb/" \
    --generator_ckpt_path "$CHECKPOINT_PATH/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin"
27 changes: 27 additions & 0 deletions
27
experiments/sdxl/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Launch SDXL ODE-pair pretraining (main/train_sd_ode.py) with torchrun,
# using 8 nodes x 8 processes per node and a c10d rendezvous.
#
# Usage:
#   bash <this script> CHECKPOINT_PATH WANDB_ENTITY WANDB_PROJECT MASTER_IP
#
#   CHECKPOINT_PATH  directory that receives the run output and that holds the
#                    ODE pair LMDB referenced by --ode_pair_path
#   WANDB_ENTITY     forwarded to --wandb_entity
#   WANDB_PROJECT    exported to the environment only; it is not passed on the
#                    command line below
#   MASTER_IP        value given to torchrun's --rdzv_endpoint
export CHECKPOINT_PATH="$1"
export WANDB_ENTITY="$2"
export WANDB_PROJECT="$3"
export MASTER_IP="$4"

# Fail early with a readable message instead of launching with empty values.
: "${CHECKPOINT_PATH:?CHECKPOINT_PATH (argument 1) is required}"
: "${WANDB_ENTITY:?WANDB_ENTITY (argument 2) is required}"
: "${MASTER_IP:?MASTER_IP (argument 4) is required}"

# Single source of truth for the output directory name and the wandb run name.
run_name="sdxl_lr1e-5_8node_ode_pretraining_10k_cond399"

# --wandb_entity uses the caller-supplied entity (argument 2) rather than a
# hard-coded account name, so runs are logged under the caller's own entity.
torchrun --nnodes 8 --nproc_per_node=8 --rdzv_id=2345 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="$MASTER_IP" main/train_sd_ode.py \
    --generator_lr 1e-5 \
    --train_iters 100000000 \
    --output_path "$CHECKPOINT_PATH/$run_name" \
    --grid_size 1 \
    --log_iters 1000 \
    --resolution 1024 \
    --seed 10 \
    --max_grad_norm 10.0 \
    --model_id "stabilityai/stable-diffusion-xl-base-1.0" \
    --wandb_iters 250 \
    --wandb_entity "$WANDB_ENTITY" \
    --wandb_name "$run_name" \
    --sdxl \
    --num_ode_pairs 10000 \
    --ode_pair_path "$CHECKPOINT_PATH/laion6.25_pair_generation_sdxl_guidance6_full_lmdb/" \
    --ode_batch_size 4 \
    --conditioning_timestep 399 \
    --tiny_vae \
    --use_fp16
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Download the ODE-pretrained SDXL generator checkpoint into CHECKPOINT_PATH.
#
# Usage:
#   bash <this script> CHECKPOINT_PATH
#
#   CHECKPOINT_PATH  existing directory the checkpoint file is written into
#
# An assignment must not carry a leading "$": "$CHECKPOINT_PATH=$1" is expanded
# and executed as a command, leaving the variable unset.
export CHECKPOINT_PATH="$1"
: "${CHECKPOINT_PATH:?CHECKPOINT_PATH (argument 1) is required}"

ckpt_name="sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin"

# The URL is quoted because "?" is a glob character to the shell.
wget "https://huggingface.co/tianweiy/DMD2/resolve/main/model/sdxl/${ckpt_name}?download=true" \
    -O "$CHECKPOINT_PATH/$ckpt_name"