@@ -588,18 +588,23 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight

        all_unique_keys = {
-            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+            for k in state_dict
+            if not k.startswith(("lora_unet_"))
        }
-        all_unique_keys = sorted(all_unique_keys)
-        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+        assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"

+        has_te_keys = False
        for k in all_unique_keys:
            if k.startswith("lora_transformer_single_transformer_blocks_"):
                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"single_transformer_blocks.{i}"
            elif k.startswith("lora_transformer_transformer_blocks_"):
                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"transformer_blocks.{i}"
+            elif k.startswith("lora_te1_"):
+                has_te_keys = True
+                continue
            else:
                raise NotImplementedError

@@ -615,17 +620,57 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
                remaining = k.split("attn_")[-1]
                diffusers_key += f".attn.{remaining}"

-            if diffusers_key == f"transformer_blocks.{i}":
-                print(k, diffusers_key)
            _convert(k, diffusers_key, state_dict, new_state_dict)

+        if has_te_keys:
+            layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+            attn_mapping = {
+                "q_proj": ".self_attn.q_proj",
+                "k_proj": ".self_attn.k_proj",
+                "v_proj": ".self_attn.v_proj",
+                "out_proj": ".self_attn.out_proj",
+            }
+            mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+            for k in all_unique_keys:
+                if not k.startswith("lora_te1_"):
+                    continue
+
+                match = layer_pattern.search(k)
+                if not match:
+                    continue
+                i = int(match.group(1))
+                diffusers_key = f"text_model.encoder.layers.{i}"
+
+                if "attn" in k:
+                    for key_fragment, suffix in attn_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+                elif "mlp" in k:
+                    for key_fragment, suffix in mlp_mapping.items():
+                        if key_fragment in k:
+                            diffusers_key += suffix
+                            break
+
+                _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if state_dict:
+            remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+            if remaining_all_unet:
+                keys = list(state_dict.keys())
+                for k in keys:
+                    state_dict.pop(k)
+
        if len(state_dict) > 0:
            raise ValueError(
                f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
            )

-        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
-        return new_state_dict
+        transformer_state_dict = {
+            f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+        }
+        te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+        return {**transformer_state_dict, **te_state_dict}

    # This is weird.
    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
@@ -640,6 +685,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
    )
    if has_mixture:
        return _convert_mixture_state_dict_to_diffusers(state_dict)
+
    return _convert_sd_scripts_to_ai_toolkit(state_dict)

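For reference, the key remapping introduced by the new `lora_te1_` branch can be reproduced in isolation. The sketch below is not part of the commit; the helper name `remap_te1_key` and the example key are hypothetical, but the regex and the attn/mlp fragment mappings mirror the added code.

# Standalone illustration (not part of the commit) of how a kohya/sd-scripts
# text-encoder module key is mapped onto a diffusers text-encoder module path.
# The helper name and the example key are made up for demonstration.
import re

layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
attn_mapping = {
    "q_proj": ".self_attn.q_proj",
    "k_proj": ".self_attn.k_proj",
    "v_proj": ".self_attn.v_proj",
    "out_proj": ".self_attn.out_proj",
}
mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}


def remap_te1_key(k):
    # Pull the encoder layer index out of the flattened key.
    match = layer_pattern.search(k)
    i = int(match.group(1))
    diffusers_key = f"text_model.encoder.layers.{i}"
    # Append the attention or MLP sub-module suffix, mirroring the added loop.
    mapping = attn_mapping if "attn" in k else mlp_mapping
    for key_fragment, suffix in mapping.items():
        if key_fragment in k:
            diffusers_key += suffix
            break
    # Converted entries are later prefixed with "text_encoder." when the state
    # dict is split into transformer and text-encoder parts.
    return f"text_encoder.{diffusers_key}"


print(remap_te1_key("lora_te1_text_model_encoder_layers_7_self_attn_q_proj"))
# text_encoder.text_model.encoder.layers.7.self_attn.q_proj

This split is also why the final return no longer prefixes everything with "transformer.": entries under "text_model." go to the text encoder, and the rest stay with the transformer.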