Commit c14057c

[LoRA] improve lora support for flux. (huggingface#10810)
update lora support for flux.
1 parent 3579cd2 commit c14057c

1 file changed: +53 −7 lines
src/diffusers/loaders/lora_conversion_utils.py

@@ -588,18 +588,23 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
             new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
 
         all_unique_keys = {
-            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+            for k in state_dict
+            if not k.startswith(("lora_unet_"))
         }
-        all_unique_keys = sorted(all_unique_keys)
-        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+        assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
 
+        has_te_keys = False
         for k in all_unique_keys:
             if k.startswith("lora_transformer_single_transformer_blocks_"):
                 i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
                 diffusers_key = f"single_transformer_blocks.{i}"
             elif k.startswith("lora_transformer_transformer_blocks_"):
                 i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                 diffusers_key = f"transformer_blocks.{i}"
+            elif k.startswith("lora_te1_"):
+                has_te_keys = True
+                continue
             else:
                 raise NotImplementedError
 
@@ -615,17 +620,57 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
                 remaining = k.split("attn_")[-1]
                 diffusers_key += f".attn.{remaining}"
 
-            if diffusers_key == f"transformer_blocks.{i}":
-                print(k, diffusers_key)
             _convert(k, diffusers_key, state_dict, new_state_dict)
 
+        if has_te_keys:
+            layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+            attn_mapping = {
+                "q_proj": ".self_attn.q_proj",
+                "k_proj": ".self_attn.k_proj",
+                "v_proj": ".self_attn.v_proj",
+                "out_proj": ".self_attn.out_proj",
+            }
+            mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+            for k in all_unique_keys:
+                if not k.startswith("lora_te1_"):
+                    continue
+
+                match = layer_pattern.search(k)
+                if not match:
+                    continue
+                i = int(match.group(1))
+                diffusers_key = f"text_model.encoder.layers.{i}"
+
+                if "attn" in k:
+                    for key_fragment, suffix in attn_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+                elif "mlp" in k:
+                    for key_fragment, suffix in mlp_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+
+                _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if state_dict:
+            remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+            if remaining_all_unet:
+                keys = list(state_dict.keys())
+                for k in keys:
+                    state_dict.pop(k)
+
         if len(state_dict) > 0:
             raise ValueError(
                 f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
             )
 
-        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
-        return new_state_dict
+        transformer_state_dict = {
+            f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+        }
+        te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+        return {**transformer_state_dict, **te_state_dict}
 
     # This is weird.
     # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
@@ -640,6 +685,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     )
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
+
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
 
 
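To make the new text-encoder handling concrete, here is a minimal standalone sketch of the lora_te1_* key mapping this commit adds. The regex and mapping tables mirror the diff above; the sample key is a hypothetical sd-scripts-style CLIP key with its ".lora_down.weight" suffix already stripped:

import re

# Sketch of the lora_te1_* -> diffusers key mapping (mirrors the diff above;
# not the library function itself).
layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
attn_mapping = {
    "q_proj": ".self_attn.q_proj",
    "k_proj": ".self_attn.k_proj",
    "v_proj": ".self_attn.v_proj",
    "out_proj": ".self_attn.out_proj",
}
mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}

# Hypothetical input key, for illustration only.
k = "lora_te1_text_model_encoder_layers_7_self_attn_q_proj"

i = int(layer_pattern.search(k).group(1))
diffusers_key = f"text_model.encoder.layers.{i}"
mapping = attn_mapping if "attn" in k else mlp_mapping
for key_fragment, suffix in mapping.items():
    if key_fragment in k:
        diffusers_key += suffix
        break

print(diffusers_key)  # text_model.encoder.layers.7.self_attn.q_proj

After _convert fills in the lora_A/lora_B suffixes, the new return path prefixes these keys with "text_encoder." (while transformer keys keep the "transformer." prefix), yielding e.g. text_encoder.text_model.encoder.layers.7.self_attn.q_proj.lora_A.weight.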

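End to end, this conversion runs automatically when such a checkpoint is loaded; a usage sketch, assuming a hypothetical LoRA repo id and weight file name:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Hypothetical checkpoint: a Kohya/sd-scripts Flux LoRA that also carries
# lora_te1_* (CLIP text encoder) entries. With this commit, those entries are
# converted and applied to the text encoder instead of tripping the assert.
pipe.load_lora_weights("user/flux-lora-with-te", weight_name="lora.safetensors")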