@@ -63,6 +63,9 @@ def _maybe_adjust_config(config):
     method removes the ambiguity by following what is described here:
     https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
     """
+    # Track keys that have been explicitly removed to prevent re-adding them.
+    deleted_keys = set()
+
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
     original_r = config["r"]
@@ -80,21 +83,22 @@ def _maybe_adjust_config(config):
         ambiguous_key = key
 
         if exact_matches and substring_matches:
-            # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
+            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
             config["r"] = key_rank
-            # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
+            # remove the ambiguous key from `rank_pattern` and record it as deleted
             del config["rank_pattern"][key]
+            deleted_keys.add(key)
+            # For substring matches, add them with the original rank only if they haven't been assigned already
             for mod in substring_matches:
-                # avoid overwriting if the module already has a specific rank
-                if mod not in config["rank_pattern"]:
+                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r
 
-            # update the rest of the keys with the `original_r`
+            # Update the rest of the target modules with the original rank if not already set and not deleted
             for mod in target_modules:
-                if mod != ambiguous_key and mod not in config["rank_pattern"]:
+                if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r
 
-    # handle alphas to deal with cases like
+    # Handle alphas to deal with cases like:
     # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
     has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
     if has_different_ranks:
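
To make the ambiguity concrete, here is a small before/after trace of the adjusted logic above; the module names and ranks are made up for illustration and are not taken from the PR:

    # Hypothetical input config; names and ranks are illustrative only.
    config = {
        "r": 4,
        "rank_pattern": {"proj_out": 8},                      # the ambiguous key
        "target_modules": ["proj_out", "blocks.0.proj_out"],  # exact + substring match
    }

    # "proj_out" matches both exactly and as a substring, so the ambiguous rank becomes
    # the default, the key is deleted (and recorded in `deleted_keys` so it is never
    # re-added), and the substring match keeps the original default rank:
    #   config["r"]            -> 8
    #   config["rank_pattern"] -> {"blocks.0.proj_out": 4}
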
@@ -187,6 +191,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
 
+        try:
+            from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+        except ImportError:
+            FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
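
The guarded import above doubles as a feature probe: whether `FULLY_QUALIFIED_PATTERN_KEY_PREFIX` is available decides which code path runs later. A minimal standalone sketch of that probe, assuming only that the constant may or may not exist in the installed PEFT:

    # Probe whether the installed PEFT exposes fully-qualified rank_pattern keys.
    try:
        from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
    except ImportError:  # older PEFT versions do not define the constant
        FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

    print("fully-qualified pattern keys supported:", FULLY_QUALIFIED_PATTERN_KEY_PREFIX is not None)
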
@@ -251,14 +260,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    rank[key] = val.shape[1]
+                    # Support for treating layer patterns as fully qualified layer names was
+                    # added later in PEFT, so we handle both the old and the new behavior here.
+                    # TODO: once we bump the minimal required PEFT version for Diffusers,
+                    # we should remove `_maybe_adjust_config()`.
+                    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                        rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    else:
+                        rank[key] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
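
For reference, a minimal sketch of how the two branches above shape the `rank` dict; the prefix value, module name, and tensor shape are placeholders rather than values from PEFT or the PR:

    import torch

    # Placeholder for illustration only; the real value comes from
    # `peft.utils.constants` when the installed PEFT provides it.
    FULLY_QUALIFIED_PATTERN_KEY_PREFIX = "<prefix>"

    key = "blocks.0.proj_out.lora_B.weight"  # hypothetical lora_B key
    val = torch.zeros(3072, 8)               # lora_B weight: (out_features, rank)

    rank = {}
    if "lora_B" in key and val.ndim > 1:
        if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
            # Newer PEFT: the prefixed key is treated as a full layer name, so no
            # substring ambiguity can arise and `_maybe_adjust_config()` is skipped.
            rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
        else:
            # Older PEFT: bare keys can match ambiguously, which is what
            # `_maybe_adjust_config()` compensates for.
            rank[key] = val.shape[1]

    print(rank)  # {"<prefix>blocks.0.proj_out.lora_B.weight": 8}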