Skip to content

Commit

Permalink
Fix EMA in train_text_to_image_sdxl.py (huggingface#7048)
Browse files Browse the repository at this point in the history
* Fix typos
  • Loading branch information
tolgacangoz authored Feb 26, 2024
1 parent d603ccb commit ad310af
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,9 @@ def collate_fn(examples):
unet, optimizer, train_dataloader, lr_scheduler
)

if args.use_ema:
ema_unet.to(accelerator.device)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
Expand Down Expand Up @@ -1126,6 +1129,8 @@ def compute_time_ids(original_size, crops_coords_top_left):

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
Expand Down

0 comments on commit ad310af

Please sign in to comment.