Skip to content

Commit

Permalink
FIX [PEFT / Core] Copy the state dict when passing it to `load_lo…
Browse files Browse the repository at this point in the history
…ra_weights` (huggingface#7058)

* copy the state dict in load lora weights

* fixup
  • Loading branch information
younesbelkada authored Feb 27, 2024
1 parent 5aa31bd commit 8a69273
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def load_lora_weights(
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

Expand Down Expand Up @@ -1229,6 +1233,10 @@ def load_lora_weights(
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
Expand Down
36 changes: 36 additions & 0 deletions tests/lora/test_lora_layers_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard
from packaging import version
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import (
AutoencoderKL,
AutoPipelineForImage2Image,
AutoPipelineForText2Image,
ControlNetModel,
DDIMScheduler,
DiffusionPipeline,
Expand Down Expand Up @@ -1745,6 +1747,40 @@ def test_load_unload_load_kohya_lora(self):
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
release_memory(pipe)

def test_not_empty_state_dict(self):
# Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
pipe = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
lcm_lora = load_file(cached_file)

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {})
release_memory(pipe)

def test_load_unload_load_state_dict(self):
# Makes sure https://github.com/huggingface/diffusers/issues/7054 does not happen again
pipe = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

cached_file = hf_hub_download("hf-internal-testing/lcm-lora-test-sd-v1-5", "test_lora.safetensors")
lcm_lora = load_file(cached_file)
previous_state_dict = lcm_lora.copy()

pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)

pipe.unload_lora_weights()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)

release_memory(pipe)


@slow
@require_torch_gpu
Expand Down

0 comments on commit 8a69273

Please sign in to comment.