Skip to content

Commit

Permalink
fix lora conversion script
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweiy committed May 30, 2024
1 parent b2cb9b4 commit 4b844a8
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions main/sdxl/extract_lora_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig, get_peft_model_state_dict
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
import argparse
import torch

Expand Down Expand Up @@ -49,9 +49,26 @@ def main():
if args.fp16:
generator = generator.half()

unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(generator))
StableDiffusionXLPipeline.save_lora_weights(args.output_model_path, unet_lora_layers=unet_lora_state_dict)
unet_lora_state_dict = get_peft_model_state_dict(generator)

new_state_dict = {}

for k, v in unet_lora_state_dict.items():

if "lora_A" in k:
k = k.replace("lora_A", "lora_down")
elif "lora_B" in k:
k = k.replace("lora_B", "lora_up")

k = "lora_unet_" + "_".join(k.split(".")[:-2]) + "." + ".".join(k.split(".")[-2:])

new_state_dict[k] = v

alpha_key = k[:k.find(".")]+".alpha"

new_state_dict[alpha_key] = torch.tensor(args.lora_alpha, dtype=torch.float16 if args.fp16 else torch.float32)

save_file(new_state_dict, args.output_model_path)

if __name__ == "__main__":
main()

0 comments on commit 4b844a8

Please sign in to comment.