
[textual_inversion] unwrap_model text encoder before accessing weights (huggingface#1816)

* unwrap_model text encoder before accessing weights

* fix another call

* fix the right call
patil-suraj authored Dec 23, 2022
1 parent f2acfb6 commit 9be94d9
Showing 1 changed file with 4 additions and 2 deletions.

examples/textual_inversion/textual_inversion.py
@@ -592,7 +592,7 @@ def main():
     progress_bar.set_description("Steps")

     # keep original embeddings as reference
-    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
+    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

     for epoch in range(first_epoch, args.num_train_epochs):
         text_encoder.train()
@@ -644,7 +644,9 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                 with torch.no_grad():
-                    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                        index_no_updates
+                    ] = orig_embeds_params[index_no_updates]

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
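For context: under multi-GPU training, accelerator.prepare can wrap the text encoder in torch.nn.parallel.DistributedDataParallel, which proxies forward() but not model-specific methods such as get_input_embeddings(), so the module must be unwrapped before its weights are touched. A minimal sketch of the pattern follows; it is illustrative and not part of the commit, and the checkpoint name and placeholder token id are assumptions.

# Minimal sketch of the unwrap-before-access pattern, assuming an
# Accelerate-style training setup (checkpoint name is illustrative).
import torch
from accelerate import Accelerator
from transformers import CLIPTextModel

accelerator = Accelerator()
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = accelerator.prepare(text_encoder)

# With more than one process, `text_encoder` is now typically a
# DistributedDataParallel wrapper; calling get_input_embeddings() on the
# wrapper would raise AttributeError, so fetch the underlying module first.
embeddings = accelerator.unwrap_model(text_encoder).get_input_embeddings()
orig_embeds_params = embeddings.weight.data.clone()

# The masked restore from the second hunk then keeps only the placeholder
# token trainable; the id below is hypothetical (the script derives it from
# the tokenizer after adding the new token).
placeholder_token_id = embeddings.weight.shape[0] - 1
index_no_updates = torch.arange(embeddings.weight.shape[0]) != placeholder_token_id
with torch.no_grad():
    embeddings.weight[index_no_updates] = orig_embeds_params[index_no_updates]

The same unwrap call appears in both hunks because the wrapped module is used everywhere after accelerator.prepare: accessing the weights through the wrapper happens to work on a single device but fails under DDP, which is the failure this commit addresses.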
