Commit fb57c76

[LoRA] refactor lora loading at the model-level (#11719)
* factor out stuff from load_lora_adapter().
* simplifying text encoder lora loading.
* fix peft.py
* fix logging locations.
* formatting
* fix
* update
* update
* update
1 parent 7251bb4 commit fb57c76

File tree

4 files changed: 107 additions, 140 deletions

src/diffusers/loaders/lora_base.py

Lines changed: 12 additions & 51 deletions
```diff
@@ -34,7 +34,6 @@
     delete_adapter_layers,
     deprecate,
     get_adapter_name,
-    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
     is_peft_version,
@@ -46,14 +45,13 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config
 from ..utils.state_dict_utils import _load_sft_state_dict_metadata
 
 
 if is_transformers_available():
     from transformers import PreTrainedModel
 
-    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer
 
@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
            )
        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-    from peft import LoraConfig
-
     # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
     # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
     # their prefixes.
@@ -377,60 +373,25 @@ def _load_lora_into_text_encoder(
         # convert state dict
         state_dict = convert_state_dict_to_peft(state_dict)
 
-        for name, _ in text_encoder_attn_modules(text_encoder):
-            for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
-
-        for name, _ in text_encoder_mlp_modules(text_encoder):
-            for module in ("fc1", "fc2"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
+        for name, _ in text_encoder.named_modules():
+            if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
+                rank_key = f"{name}.lora_B.weight"
+                if rank_key in state_dict:
+                    rank[rank_key] = state_dict[rank_key].shape[1]
 
         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
 
-        if metadata is not None:
-            lora_config_kwargs = metadata
-        else:
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
-
-        if "use_dora" in lora_config_kwargs:
-            if lora_config_kwargs["use_dora"]:
-                if is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<", "0.9.0"):
-                    lora_config_kwargs.pop("use_dora")
-
-        if "lora_bias" in lora_config_kwargs:
-            if lora_config_kwargs["lora_bias"]:
-                if is_peft_version("<=", "0.13.2"):
-                    raise ValueError(
-                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<=", "0.13.2"):
-                    lora_config_kwargs.pop("lora_bias")
-
-        try:
-            lora_config = LoraConfig(**lora_config_kwargs)
-        except TypeError as e:
-            raise TypeError("`LoraConfig` class could not be instantiated.") from e
+        # create `LoraConfig`
+        lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)
 
         # adapter_name
         if adapter_name is None:
            adapter_name = get_adapter_name(text_encoder)
 
+        # <Unsafe code
        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
-
        # inject LoRA layers and load the state dict
        # in transformers we automatically check whether the adapter name is already in use or not
        text_encoder.load_adapter(
@@ -442,7 +403,6 @@ def _load_lora_into_text_encoder(
 
        # scale LoRA layers with `lora_scale`
        scale_lora_layers(text_encoder, weight=lora_scale)
-
        text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
 
        # Offload back.
@@ -453,10 +413,11 @@ def _load_lora_into_text_encoder(
        # Unsafe code />
 
    if prefix is not None and not state_dict:
+        model_class_name = text_encoder.__class__.__name__
        logger.warning(
-            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
            "This is safe to ignore if LoRA state dict didn't originally have any "
-            f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+            f"{model_class_name} related params. You can also try specifying `prefix=None` "
            "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
            "https://github.com/huggingface/diffusers/issues/new"
        )
```
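Note on the rank-detection change above: the two module-specific loops (attention and MLP) collapse into a single pass over `named_modules()`, keyed on the module-name suffix. Below is a minimal self-contained sketch of the new logic, with a toy encoder and an invented one-entry state dict standing in for a real CLIP text encoder and LoRA checkpoint:

```python
import torch
import torch.nn as nn


# Toy stand-in for a CLIP-style text encoder; only the submodule names matter here.
class ToyAttention(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)


text_encoder = nn.Sequential(ToyAttention())

# Invented PEFT-format entry: lora_B weights have shape (out_features, rank),
# so shape[1] below recovers the rank, here 4.
state_dict = {"0.q_proj.lora_B.weight": torch.zeros(8, 4)}

rank = {}
for name, _ in text_encoder.named_modules():
    if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
        rank_key = f"{name}.lora_B.weight"
        if rank_key in state_dict:
            rank[rank_key] = state_dict[rank_key].shape[1]

print(rank)  # {'0.q_proj.lora_B.weight': 4}
```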

src/diffusers/loaders/peft.py

Lines changed: 9 additions & 83 deletions
```diff
@@ -29,13 +29,13 @@
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales
 
@@ -64,26 +64,6 @@
 }
 
 
-def _maybe_raise_error_for_ambiguity(config):
-    rank_pattern = config["rank_pattern"].copy()
-    target_modules = config["target_modules"]
-
-    for key in list(rank_pattern.keys()):
-        # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
-        # But this cuts it for now.
-        exact_matches = [mod for mod in target_modules if mod == key]
-        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-
-        if exact_matches and substring_matches:
-            if is_peft_version("<", "0.14.1"):
-                raise ValueError(
-                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
-                )
-
-
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -191,7 +171,7 @@ def load_lora_adapter(
                LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
                initialize `LoraConfig`.
        """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft import inject_adapter_in_model, set_peft_model_state_dict
        from peft.tuners.tuners_utils import BaseTunerLayer
 
        cache_dir = kwargs.pop("cache_dir", None)
@@ -216,7 +196,6 @@ def load_lora_adapter(
        )
 
        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
-
        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
@@ -275,38 +254,8 @@ def load_lora_adapter(
                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                }
 
-            if metadata is not None:
-                lora_config_kwargs = metadata
-            else:
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
-                )
-            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
-
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
-
-            try:
-                lora_config = LoraConfig(**lora_config_kwargs)
-            except TypeError as e:
-                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+            # create LoraConfig
+            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
 
            # adapter_name
            if adapter_name is None:
@@ -317,9 +266,8 @@ def load_lora_adapter(
            # Now we remove any existing hooks to `_pipeline`.
 
            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
+            # otherwise loading LoRA weights will lead to an error.
            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
-
            peft_kwargs = {}
            if is_peft_version(">=", "0.13.1"):
                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -403,30 +351,7 @@ def map_state_dict_for_hotswap(sd):
                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
                raise
 
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
+            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
 
            # Offload back.
            if is_model_cpu_offload:
@@ -436,10 +361,11 @@ def map_state_dict_for_hotswap(sd):
            # Unsafe code />
 
        if prefix is not None and not state_dict:
+            model_class_name = self.__class__.__name__
            logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
                "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+                f"{model_class_name} related params. You can also try specifying `prefix=None` "
                "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
                "https://github.com/huggingface/diffusers/issues/new"
            )
```
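The incompatible-keys block removed from `load_lora_adapter` above now lives in `_maybe_warn_for_unhandled_keys` (added to peft_utils.py below). A hedged sketch of the filtering it performs, using a namedtuple to stand in for the result object PEFT returns from `set_peft_model_state_dict` (the key names are invented for illustration):

```python
from collections import namedtuple

# Stand-in for the object returned by peft.set_peft_model_state_dict; the real
# result likewise exposes .missing_keys and .unexpected_keys attributes.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

adapter_name = "default_0"
incompatible_keys = IncompatibleKeys(
    missing_keys=["blocks.0.attn.to_q.lora_A.default_0.weight"],
    unexpected_keys=["text_model.embeddings.position_ids"],  # not a LoRA key: ignored
)

# Same filtering as the factored-out helper: only warn about LoRA keys that
# belong to the adapter currently being loaded.
unexpected = [k for k in incompatible_keys.unexpected_keys if "lora_" in k and adapter_name in k]
missing = [k for k in incompatible_keys.missing_keys if "lora_" in k and adapter_name in k]

if unexpected:
    print(f"unexpected LoRA keys: {', '.join(unexpected)}")
if missing:
    print(f"missing LoRA keys: {', '.join(missing)}")
```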

src/diffusers/utils/peft_utils.py

Lines changed: 84 additions & 1 deletion
```diff
@@ -21,9 +21,12 @@
 
 from packaging import version
 
-from .import_utils import is_peft_available, is_torch_available
+from . import logging
+from .import_utils import is_peft_available, is_peft_version, is_torch_available
 
 
+logger = logging.get_logger(__name__)
+
 if is_torch_available():
     import torch
 
@@ -288,3 +291,83 @@ def check_peft_version(min_version: str) -> None:
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
+
+
+def _create_lora_config(
+    state_dict,
+    network_alphas,
+    metadata,
+    rank_pattern_dict,
+    is_unet: bool = True,
+):
+    from peft import LoraConfig
+
+    if metadata is not None:
+        lora_config_kwargs = metadata
+    else:
+        lora_config_kwargs = get_peft_kwargs(
+            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+        )
+
+    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
+
+    # Version checks for DoRA and lora_bias
+    if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
+
+    if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
+        if is_peft_version("<=", "0.13.2"):
+            raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")
+
+    try:
+        return LoraConfig(**lora_config_kwargs)
+    except TypeError as e:
+        raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
+
+def _maybe_raise_error_for_ambiguous_keys(config):
+    rank_pattern = config["rank_pattern"].copy()
+    target_modules = config["target_modules"]
+
+    for key in list(rank_pattern.keys()):
+        # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+        # But this cuts it for now.
+        exact_matches = [mod for mod in target_modules if mod == key]
+        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+
+        if exact_matches and substring_matches:
+            if is_peft_version("<", "0.14.1"):
+                raise ValueError(
+                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+                )
+
+
+def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
+    warn_msg = ""
+    if incompatible_keys is not None:
+        # Check only for unexpected keys.
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+        if unexpected_keys:
+            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+            if lora_unexpected_keys:
+                warn_msg = (
+                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                    f" {', '.join(lora_unexpected_keys)}. "
+                )
+
+        # Filter missing keys specific to the current adapter.
+        missing_keys = getattr(incompatible_keys, "missing_keys", None)
+        if missing_keys:
+            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+            if lora_missing_keys:
+                warn_msg += (
+                    f"Loading adapter weights from state_dict led to missing keys in the model:"
+                    f" {', '.join(lora_missing_keys)}."
+                )
+
+    if warn_msg:
+        logger.warning(warn_msg)
```
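To make `_maybe_raise_error_for_ambiguous_keys` concrete: it raises on older PEFT versions when a `rank_pattern` key matches one `target_modules` entry exactly while also appearing as a substring of another, since pattern resolution was ambiguous before peft 0.14.1. A standalone walkthrough of the matching, with invented config values:

```python
# Invented config: "to_q" is an exact target-module entry *and* a substring of
# "blocks.1.to_q" -- the situation the helper guards against on peft < 0.14.1.
config = {
    "rank_pattern": {"to_q": 8},
    "target_modules": ["to_q", "blocks.1.to_q", "to_k"],
}

rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]

for key in list(rank_pattern.keys()):
    exact_matches = [mod for mod in target_modules if mod == key]
    substring_matches = [mod for mod in target_modules if key in mod and mod != key]
    if exact_matches and substring_matches:
        # The real helper raises ValueError here when peft < 0.14.1; newer peft
        # resolves the pattern unambiguously, so loading proceeds.
        print(f"ambiguous rank_pattern key: {key!r} also matches {substring_matches}")
```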
