     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales
 }


-def _maybe_raise_error_for_ambiguity(config):
-    rank_pattern = config["rank_pattern"].copy()
-    target_modules = config["target_modules"]
-
-    for key in list(rank_pattern.keys()):
-        # Try to detect ambiguity.
-        # `target_modules` can also be a str, in which case this loop would iterate
-        # over the characters of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use `LoraModel._check_target_module_exists(lora_config, key)`,
-        # but this suffices for now.
-        exact_matches = [mod for mod in target_modules if mod == key]
-        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-
-        if exact_matches and substring_matches:
-            if is_peft_version("<", "0.14.1"):
-                raise ValueError(
-                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
-                )
-
-
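Worth noting what the removed helper guarded against (it presumably moved into `..utils.peft_utils` alongside the newly imported helpers): a `rank_pattern` key can be an exact match for one entry of `target_modules` and a substring of another, in which case `peft` older than 0.14.1 could apply a rank override to the wrong module. A minimal illustration, with hypothetical module names:

```python
# Hypothetical module names; illustrates the ambiguity the removed helper detected.
rank_pattern = {"proj": 4}
target_modules = ["proj", "to_out.proj"]

key = "proj"
exact_matches = [mod for mod in target_modules if mod == key]  # ["proj"]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]  # ["to_out.proj"]

# Both lists are non-empty, so under peft < 0.14.1 it is unclear whether the
# rank override for "proj" should also apply to "to_out.proj"; the helper
# raised a ValueError asking the user to upgrade `peft`.
assert exact_matches and substring_matches
```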
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapter weights that are supported in the PEFT library. For
@@ -191,7 +171,7 @@ def load_lora_adapter(
                 LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
                 initialize `LoraConfig`.
         """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft import inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

         cache_dir = kwargs.pop("cache_dir", None)
@@ -216,7 +196,6 @@ def load_lora_adapter(
         )

         user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
-
         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
@@ -275,38 +254,8 @@ def load_lora_adapter(
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            if metadata is not None:
-                lora_config_kwargs = metadata
-            else:
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
-                )
-            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
-
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
-
-            try:
-                lora_config = LoraConfig(**lora_config_kwargs)
-            except TypeError as e:
-                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+            # create LoraConfig
+            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

             # adapter_name
             if adapter_name is None:
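The body of `_create_lora_config` is not part of this diff; it lives in `..utils.peft_utils`. A plausible sketch of what it consolidates, reassembled from the inline logic removed above — the parameter names mirror the call site and are assumptions, while `get_peft_kwargs`, `is_peft_version`, and `_maybe_raise_error_for_ambiguity` are the existing utilities:

```python
# Sketch only: reassembled from the removed inline logic above. The real
# implementation lives in ..utils.peft_utils and may differ in detail.
from peft import LoraConfig


def _create_lora_config(state_dict, network_alphas, metadata, rank):
    # Prefer explicit metadata; otherwise infer LoraConfig kwargs from the state dict.
    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
    _maybe_raise_error_for_ambiguity(lora_config_kwargs)

    # Reject, or silently drop, kwargs the installed `peft` version cannot handle.
    if is_peft_version("<", "0.9.0") and "use_dora" in lora_config_kwargs:
        if lora_config_kwargs["use_dora"]:
            raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
        lora_config_kwargs.pop("use_dora")

    if is_peft_version("<=", "0.13.2") and "lora_bias" in lora_config_kwargs:
        if lora_config_kwargs["lora_bias"]:
            raise ValueError("You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs.")
        lora_config_kwargs.pop("lora_bias")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e
```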
@@ -317,9 +266,8 @@ def load_lora_adapter(
         # Now we remove any existing hooks to `_pipeline`.

         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-        # otherwise loading LoRA weights will lead to an error
+        # otherwise loading LoRA weights will lead to an error.
         is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
-
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
             peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -403,30 +351,7 @@ def map_state_dict_for_hotswap(sd):
                 logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
                 raise

-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
+            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)

             # Offload back.
             if is_model_cpu_offload:
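Likewise, `_maybe_warn_for_unhandled_keys` replaces the removed `warn_msg` block one-for-one. A sketch reassembled from the removed lines — the real body lives in `..utils.peft_utils` and is assumed to use the module-level diffusers logger:

```python
# Sketch only: the removed warn_msg block, factored into the helper the diff now calls.
from diffusers.utils import logging

logger = logging.get_logger(__name__)


def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
    warn_msg = ""
    if incompatible_keys is not None:
        # Check only for unexpected keys.
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
            if lora_unexpected_keys:
                warn_msg = (
                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
                    f" {', '.join(lora_unexpected_keys)}. "
                )

        # Filter missing keys specific to the current adapter.
        missing_keys = getattr(incompatible_keys, "missing_keys", None)
        if missing_keys:
            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
            if lora_missing_keys:
                warn_msg += (
                    f"Loading adapter weights from state_dict led to missing keys in the model:"
                    f" {', '.join(lora_missing_keys)}."
                )

    if warn_msg:
        logger.warning(warn_msg)
```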
@@ -436,10 +361,11 @@ def map_state_dict_for_hotswap(sd):
             # Unsafe code />

         if prefix is not None and not state_dict:
+            model_class_name = self.__class__.__name__
             logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
                 "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+                f"{model_class_name} related params. You can also try specifying `prefix=None` "
                 "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
                 "https://github.com/huggingface/diffusers/issues/new"
             )
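For context, a minimal usage sketch of the method this diff edits; the checkpoint and LoRA repo names are hypothetical, and `prefix=None` is the escape hatch the warning above suggests:

```python
# Hypothetical repo/file names; shows the code path this diff touches.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("some-org/some-sd-checkpoint", subfolder="unet")

# Load a LoRA state dict directly into the model. Passing `prefix=None`
# disables the prefix filtering that otherwise triggers the
# "No LoRA keys associated to ..." warning above.
unet.load_lora_adapter(
    "some-user/some-lora",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="my_lora",
    prefix=None,
)
```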