Commit b0550a6

[LoRA] restrict certain keys to be checked for peft config update. (huggingface#10808)
* restrict certain keys to be checked for peft config update.
* updates
* finish.
* finish 2.
* updates
1 parent 6f74ef5 commit b0550a6

File tree

1 file changed (+26, -9)
src/diffusers/loaders/peft.py

@@ -63,6 +63,9 @@ def _maybe_adjust_config(config):
     method removes the ambiguity by following what is described here:
     https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
     """
+    # Track keys that have been explicitly removed to prevent re-adding them.
+    deleted_keys = set()
+
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
     original_r = config["r"]
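
For context, the ambiguity this function resolves: a `rank_pattern` key such as `proj_out` can match one target module exactly and other modules as a substring. A minimal sketch re-stating the detection logic from the surrounding function, with invented module names:

# Toy illustration of the ambiguity check in _maybe_adjust_config (module
# names are invented; the comprehensions mirror the function's own logic).
key = "proj_out"
target_modules = ["proj_out", "blocks.0.proj_out", "blocks.1.proj_out"]

exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and key != mod]

print(exact_matches)      # ['proj_out']
print(substring_matches)  # ['blocks.0.proj_out', 'blocks.1.proj_out']
# Both lists are non-empty, so it is unclear whether the rank recorded for
# "proj_out" belongs to the top-level module or to the nested ones.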
@@ -80,21 +83,22 @@ def _maybe_adjust_config(config):
         ambiguous_key = key

         if exact_matches and substring_matches:
-            # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
+            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
             config["r"] = key_rank
-            # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
+            # remove the ambiguous key from `rank_pattern` and record it as deleted
             del config["rank_pattern"][key]
+            deleted_keys.add(key)
+            # For substring matches, add them with the original rank only if they haven't been assigned already
             for mod in substring_matches:
-                # avoid overwriting if the module already has a specific rank
-                if mod not in config["rank_pattern"]:
+                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r

-            # update the rest of the keys with the `original_r`
+            # Update the rest of the target modules with the original rank if not already set and not deleted
             for mod in target_modules:
-                if mod != ambiguous_key and mod not in config["rank_pattern"]:
+                if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r

-    # handle alphas to deal with cases like
+    # Handle alphas to deal with cases like:
     # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
     has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
     if has_different_ranks:
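
To see what the new `deleted_keys` set changes, here is a standalone re-implementation of the patched loop (for illustration only, not an import from diffusers) run on a toy config with two overlapping pattern keys. Before this commit, a key deleted in an earlier iteration could be re-added with the default rank by a later one:

def adjust(config):
    # Standalone re-implementation of the patched logic, for illustration only.
    deleted_keys = set()
    rank_pattern = config["rank_pattern"].copy()
    target_modules = config["target_modules"]
    original_r = config["r"]

    for key in list(rank_pattern.keys()):
        key_rank = rank_pattern[key]
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and key != mod]
        ambiguous_key = key

        if exact_matches and substring_matches:
            config["r"] = key_rank
            del config["rank_pattern"][key]
            deleted_keys.add(key)
            for mod in substring_matches:
                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                    config["rank_pattern"][mod] = original_r
            for mod in target_modules:
                if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                    config["rank_pattern"][mod] = original_r
    return config

config = {
    "r": 4,
    "rank_pattern": {"proj_out": 8, "proj": 16},  # both keys are ambiguous
    "target_modules": ["proj", "proj_out", "blocks.0.proj_out"],
}
print(adjust(config)["rank_pattern"])
# {'blocks.0.proj_out': 4} -- "proj_out" stays deleted. Without `deleted_keys`,
# the second iteration (for "proj") would re-add "proj_out" with rank 4.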
@@ -187,6 +191,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

+        try:
+            from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+        except ImportError:
+            FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
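
The guarded import is feature detection rather than a version check: newer PEFT releases expose `FULLY_QUALIFIED_PATTERN_KEY_PREFIX` (its concrete value lives in PEFT and is not shown in this diff), while older ones raise `ImportError`, leaving the constant as `None` so the loader takes the legacy path. A sketch of how the single flag gates both code paths:

# Feature detection for fully qualified rank_pattern keys; the real import
# is shown in the hunk above.
try:
    from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:  # older PEFT releases predate the constant
    FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
    # New path: rank_pattern keys are marked as exact module names,
    # so no ambiguity resolution is needed downstream.
    print("PEFT supports fully qualified pattern keys")
else:
    # Legacy path: keys behave as patterns; _maybe_adjust_config() applies.
    print("falling back to _maybe_adjust_config()")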
@@ -251,14 +260,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    rank[key] = val.shape[1]
+                    # Support to handle cases where layer patterns are treated as full layer names
+                    # was added later in PEFT. So, we handle it accordingly.
+                    # TODO: when we fix the minimal PEFT version for Diffusers,
+                    # we should remove `_maybe_adjust_config()`.
+                    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                        rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    else:
+                        rank[key] = val.shape[1]

             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
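
Putting the two changes together: with a new-enough PEFT, every rank entry is written under a fully qualified key and `_maybe_adjust_config()` is skipped entirely; otherwise the legacy keys and the adjustment pass are kept. A toy walkthrough (state-dict shapes are invented, and `"<prefix>"` merely simulates the PEFT constant, whose real value this diff does not show):

import torch

# Simulated stand-in for the PEFT constant; the real value is imported above.
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = "<prefix>"

state_dict = {  # toy LoRA weights: lora_B weight has shape (out_features, rank)
    "blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 64),
    "blocks.0.attn.to_q.lora_B.weight": torch.zeros(64, 4),
    "blocks.0.attn.to_q.lora_B.bias": torch.zeros(64),  # 1-D bias: skipped
}

rank = {}
for key, val in state_dict.items():
    if "lora_B" in key and val.ndim > 1:
        if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
            rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
        else:
            rank[key] = val.shape[1]

print(rank)  # {'<prefix>blocks.0.attn.to_q.lora_B.weight': 4}

# Downstream, the ambiguity-adjustment pass only runs on the legacy path:
# if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
#     lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)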
