Commit
[textual inversion] add gradient checkpointing and small fixes. (huggingface#1848)

Co-authored-by: Henrik Forstén <[email protected]>

* update TI script

* make flake happy

* fix typo
patil-suraj authored Dec 29, 2022
1 parent 03bf877 commit 9ea7052
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions examples/textual_inversion/textual_inversion.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import random
@@ -147,6 +146,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
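Note that the new --gradient_checkpointing option is an ordinary argparse store_true flag, so training behaves exactly as before unless the flag is passed. A minimal, self-contained sketch of that behavior (the parser setup is copied from the hunk above; the argument lists fed to parse_args are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--gradient_checkpointing",
    action="store_true",
    help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)

assert parser.parse_args([]).gradient_checkpointing is False  # flag omitted: disabled
assert parser.parse_args(["--gradient_checkpointing"]).gradient_checkpointing is True  # flag passed: enabled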
@@ -383,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +459,10 @@ def main():
         revision=args.revision,
     )
 
+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
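The two enable calls added in this hunk turn on gradient (activation) checkpointing for the text encoder and the UNet: checkpointed blocks discard their intermediate activations in the forward pass and recompute them during backward, trading extra compute for lower peak memory. A rough illustrative sketch of that mechanism with plain torch.utils.checkpoint, where TinyBlock is a hypothetical stand-in rather than the diffusers UNet:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    # Hypothetical toy block, used only to illustrate the checkpointing pattern.
    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.gradient_checkpointing = False

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            # Do not store activations of self.net; recompute them during backward.
            return checkpoint(self.net, x)
        return self.net(x)


block = TinyBlock().train()
block.gradient_checkpointing = True
x = torch.randn(8, 64, requires_grad=True)
block(x).sum().backward()  # backward recomputes the block's forward internally

The self.training gate in the sketch mirrors a common pattern in these libraries and is the reason the UNet is put back into train mode further down whenever the flag is set.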
@@ -474,15 +477,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
 
     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     if args.scale_lr:
         args.learning_rate = (
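This hunk replaces the removed freeze_params() helper with nn.Module.requires_grad_(False), which flips requires_grad on every parameter of the module in one call; that is also why the itertools import, previously used to chain the text encoder parameter groups, could be dropped at the top of the file. A short sketch of the equivalence, with nn.Linear as a placeholder module:

import torch.nn as nn

model = nn.Linear(8, 8)

# Old style: what the removed freeze_params() helper did.
for param in model.parameters():
    param.requires_grad = False

# New style: built into nn.Module and equivalent for freezing.
model.requires_grad_(False)

assert all(not p.requires_grad for p in model.parameters())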
@@ -541,9 +541,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
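The replacement comment makes two points: gradient checkpointing is generally only active while a module reports self.training as True, and because the UNet's dropout probability is 0 there is no functional difference between train and eval mode, so keeping the frozen UNet in train mode is harmless. A small sketch of the second point, using an illustrative toy layer:

import torch
import torch.nn as nn

layer = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.0))
x = torch.randn(2, 4)

layer.train()
out_train = layer(x)
layer.eval()
out_eval = layer(x)

# With p=0.0, dropout keeps every activation unscaled, so both modes agree.
assert torch.allclose(out_train, out_eval)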
@@ -609,12 +610,11 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
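torch.randn_like draws the noise tensor with the same shape, dtype, and device as the latents in a single call, replacing the chained .to(...) conversions of the old line; the timesteps edit is a pure reformatting that produces the same int64 tensor. A minimal sketch of the randn_like behavior, with a dummy tensor standing in for the VAE latents:

import torch

latents = torch.zeros(2, 4, 8, 8, dtype=torch.float64)  # illustrative stand-in for the VAE latents
noise = torch.randn_like(latents)

assert noise.shape == latents.shape
assert noise.dtype == latents.dtype
assert noise.device == latents.device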
@@ -669,8 +669,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-    accelerator.wait_for_everyone()
-
+    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds:
