Skip to content

Commit

Permalink
Added script to save during textual inversion training. Issue 524 (huggingface#645)

Browse files Browse the repository at this point in the history

* Added script to save during training

* Suggested changes
  • Loading branch information
isamu-isozaki authored Sep 28, 2022
1 parent 765506c commit 7f31142
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions examples/textual_inversion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,21 @@
logger = get_logger(__name__)


def save_progress(text_encoder, placeholder_token_id, accelerator, args):
    """Persist the learned placeholder-token embedding to disk.

    Unwraps ``text_encoder`` from the accelerator, selects the row of its
    input-embedding matrix at ``placeholder_token_id``, and saves it (moved
    to CPU and detached from the graph) to
    ``<args.output_dir>/learned_embeds.bin`` under the key
    ``args.placeholder_token``.
    """
    logger.info("Saving embeddings")
    unwrapped_encoder = accelerator.unwrap_model(text_encoder)
    embedding_matrix = unwrapped_encoder.get_input_embeddings().weight
    token_embedding = embedding_matrix[placeholder_token_id].detach().cpu()
    torch.save(
        {args.placeholder_token: token_embedding},
        os.path.join(args.output_dir, "learned_embeds.bin"),
    )


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--save_steps",
type=int,
default=500,
help="Save learned_embeds.bin every X updates steps.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
Expand Down Expand Up @@ -542,6 +555,8 @@ def main():
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
save_progress(text_encoder, placeholder_token_id, accelerator, args)

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
Expand All @@ -567,9 +582,7 @@ def main():
)
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
save_progress(text_encoder, placeholder_token_id, accelerator, args)

if args.push_to_hub:
repo.push_to_hub(
Expand Down

0 comments on commit 7f31142

Please sign in to comment.