
Commit

[examples/advanced_diffusion_training] bug fixes and improvements for LoRA Dreambooth SDXL advanced training script (huggingface#5935)

* imports and readme bug fixes

* bug fix - ensures text_encoder params are dtype==float32 (when using pivotal tuning) even if the rest of the model is loaded in fp16 (see the sketch below the changed-file summary)

* added pivotal tuning to readme

* mapping token abstraction identifiers to the newly inserted tokens in the validation prompt (if used)

* correct default value of --train_text_encoder_frac

* change default value of --adam_weight_decay_text_encoder

* bug fix for validation prompt generations when using pivotal tuning

* style fix

* textual inversion embeddings name change

* style fix

* bug fix - stopping text encoder optimization halfway

* readme - now includes the token abstraction and the new inserted tokens when using pivotal tuning
- added type to --num_new_tokens_per_abstraction

* style fix

---------

Co-authored-by: Linoy Tsaban <[email protected]>
linoytsaban and linoytsaban authored Dec 1, 2023
1 parent 7d4a257 commit d29d97b
Showing 1 changed file with 92 additions and 56 deletions.
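Before the diff itself, a minimal standalone sketch of the main fixes described in the commit message. This is not the training script: TinyTextEncoder, the TOK mapping, and the toy optimizer below are illustrative stand-ins under assumed names, and only the pattern mirrors what the diff changes.

import torch
import torch.nn as nn


class TinyTextEncoder(nn.Module):
    # Stand-in for a CLIP text encoder; only the token embedding matters for this sketch.
    def __init__(self, vocab_size: int = 16, dim: int = 8):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.mlp = nn.Linear(dim, dim)


# Pretend the text encoder was loaded in fp16 to save memory, then frozen.
text_encoder = TinyTextEncoder().to(dtype=torch.float16)
text_encoder.requires_grad_(False)

# Fix 1 (pivotal tuning): keep only the token-embedding weights trainable, and cast the
# trainable copies to float32 so optimizer updates stay numerically stable.
text_params = []
for name, param in text_encoder.named_parameters():
    if "token_embedding" in name:
        param = param.to(dtype=torch.float32)
        param.requires_grad = True
        text_params.append(param)
    else:
        param.requires_grad = False
print([p.dtype for p in text_params])  # [torch.float32]

# Fix 2 (validation prompts): replace each token abstraction (e.g. "TOK") in the validation
# prompt with the concatenated inserted tokens, so validation actually uses the new embeddings.
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}  # illustrative mapping
validation_prompt = "a photo of TOK wearing sunglasses"
for token_abs, token_replacement in token_abstraction_dict.items():
    validation_prompt = validation_prompt.replace(token_abs, "".join(token_replacement))
print(validation_prompt)  # a photo of <s0><s1> wearing sunglasses

# Fix 3 (stopping text encoder optimization halfway): rather than rebuilding the optimizer,
# the learning rate of the text encoder parameter groups can simply be zeroed at the pivot epoch.
optimizer = torch.optim.AdamW(
    [
        {"params": [nn.Parameter(torch.zeros(2))], "lr": 1e-4},  # unet stand-in
        {"params": text_params, "lr": 1e-5},  # text encoder embeddings
    ]
)
optimizer.param_groups[1]["lr"] = 0.0  # text encoder frozen from this epoch on

The actual diff applies the same patterns to both SDXL text encoders.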
@@ -54,7 +54,7 @@
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -67,11 +67,46 @@
logger = get_logger(__name__)


# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection

attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))

return attn_modules

for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

return state_dict


def save_model_card(
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
train_text_encoder_ti=False,
token_abstraction_dict=None,
instance_prompt=str,
validation_prompt=str,
repo_folder=None,
@@ -83,10 +118,23 @@ def save_model_card(
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url: >-
url:
"image_{i}.png"
"""

trigger_str = f"You should use {instance_prompt} to trigger the image generation."
if train_text_encoder_ti:
trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
"in you prompt with the new inserted tokens:\n"
)
if token_abstraction_dict:
for key, value in token_abstraction_dict.items():
tokens = "".join(value)
trigger_str += f"""
to trigger concept {key}-> use {tokens} in your prompt \n
"""

yaml = f"""
---
tags:
@@ -96,9 +144,7 @@ def save_model_card(
- diffusers
- lora
- template:sd-lora
widget:
{img_str}
---
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
@@ -112,14 +158,19 @@ def save_model_card(
## Model description
These are {repo_id} LoRA adaption weights for {base_model}.
### These are {repo_id} LoRA adaption weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
LoRA for the text encoder was enabled: {train_text_encoder}.
Pivotal tuning was enabled: {train_text_encoder_ti}.
Special VAE used for training: {vae_path}.
## Trigger words
You should use {instance_prompt} to trigger the image generation.
{trigger_str}
## Download model
@@ -244,6 +295,7 @@ def parse_args(input_args=None):

parser.add_argument(
"--num_new_tokens_per_abstraction",
type=int,
default=2,
help="number of new tokens inserted to the tokenizers per token_abstraction value when "
"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
@@ -455,7 +507,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--train_text_encoder_frac",
type=float,
default=0.5,
default=1.0,
help=("The percentage of epochs to perform text encoder tuning"),
)

@@ -488,7 +540,7 @@ def parse_args(input_args=None):
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
"--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
)

parser.add_argument(
@@ -679,12 +731,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
def save_embeddings(self, file_path: str):
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {}
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
self.tokenizers[0]
), "Tokenizers should be the same."
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
tensors[f"text_encoders_{idx}"] = new_token_embeddings

# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0) and "clip_g" (for
# text_encoder 1) to stay compatible with the rest of the ecosystem.
# Note: when loading with diffusers, any name works - simply specify it at inference
tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
# tensors[f"text_encoders_{idx}"] = new_token_embeddings

save_file(tensors, file_path)

@@ -696,19 +755,6 @@ def dtype(self):
def device(self):
return self.text_encoders[0].device

# def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
# # Assuming new tokens are of the format <s_i>
# self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
# special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
# tokenizer.add_special_tokens(special_tokens_dict)
# text_encoder.resize_token_embeddings(len(tokenizer))
#
# self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
# assert self.train_ids is not None, "New tokens could not be converted to IDs."
# text_encoder.text_model.embeddings.token_embedding.weight.data[
# self.train_ids
# ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)

@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
@@ -730,15 +776,6 @@ def retract_embeddings(self):
new_embeddings = new_embeddings * (off_ratio**0.1)
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings

# def load_embeddings(self, file_path: str):
# with safe_open(file_path, framework="pt", device=self.device.type) as f:
# for idx in range(len(self.text_encoders)):
# text_encoder = self.text_encoders[idx]
# tokenizer = self.tokenizers[idx]
#
# loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
# self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)


class DreamBoothDataset(Dataset):
"""
@@ -1216,13 +1253,17 @@ def main(args):
text_lora_parameters_one = []
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure the dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
param.requires_grad = False
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure the dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1309,12 +1350,16 @@ def load_model_hook(models, input_dir):
# different learning rate for text encoder and unet
text_lora_parameters_one_with_lr = {
"params": text_lora_parameters_one,
"weight_decay": args.adam_weight_decay_text_encoder,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_lora_parameters_two_with_lr = {
"params": text_lora_parameters_two,
"weight_decay": args.adam_weight_decay_text_encoder,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
@@ -1494,6 +1539,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

if args.train_text_encoder_ti and args.validation_prompt:
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
print("validation prompt:", args.validation_prompt)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1593,27 +1644,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
params_to_optimize = params_to_optimize[:1]
# reinitializing the optimizer to optimize only on unet params
if args.optimizer.lower() == "prodigy":
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
else: # AdamW or 8-bit-AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# re-setting the optimizer to optimize only the unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0

else:
# still optimizing the text encoder
text_encoder_one.train()
@@ -1628,7 +1662,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
print(prompts)
# print(prompts)
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
@@ -1801,7 +1835,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
f" {args.validation_prompt}."
)
# create pipeline
if not args.train_text_encoder:
if freeze_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
@@ -1948,6 +1982,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
token_abstraction_dict=train_dataset.token_abstraction_dict,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
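Related to the save_embeddings change above, which stores the new token embeddings under "clip_l" and "clip_g": a hedged inference sketch showing how such a file might be consumed with diffusers. The repo id, the embeddings.safetensors filename, and the <s0><s1> tokens are assumptions for illustration and depend on the actual training run.

import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

repo_id = "your-username/your-lora-repo"  # placeholder, not a real repo

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load the LoRA weights trained by the script (filename assumed).
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")

# Download and load the textual inversion embeddings saved by save_embeddings (filename assumed).
embedding_path = hf_hub_download(repo_id=repo_id, filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)

# Register the inserted tokens with each text encoder / tokenizer pair:
# "clip_l" -> first text encoder, "clip_g" -> second text encoder.
pipe.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer
)
pipe.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2
)

image = pipe("a photo of <s0><s1> wearing sunglasses").images[0]
image.save("sample.png")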
