
[LoRA] improve LoRA fusion tests #11274


Merged: 19 commits, May 27, 2025
24 changes: 21 additions & 3 deletions src/diffusers/loaders/lora_base.py
@@ -465,7 +465,7 @@ class LoraBaseMixin:
"""Utility class for handling LoRAs."""

_lora_loadable_modules = []
num_fused_loras = 0
_merged_adapters = set()

def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -592,6 +592,9 @@ def fuse_lora(
if len(components) == 0:
raise ValueError("`components` cannot be an empty list.")

# Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
# in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
merged_adapter_names = set()
for fuse_component in components:
if fuse_component not in self._lora_loadable_modules:
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,13 +604,19 @@
# check if diffusers model
if issubclass(model.__class__, ModelMixin):
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
for module in model.modules():
if isinstance(module, BaseTunerLayer):
merged_adapter_names.update(set(module.merged_adapters))
# handle transformers models.
if issubclass(model.__class__, PreTrainedModel):
fuse_text_encoder_lora(
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
for module in model.modules():
if isinstance(module, BaseTunerLayer):
merged_adapter_names.update(set(module.merged_adapters))

self.num_fused_loras += 1
self._merged_adapters = self._merged_adapters | merged_adapter_names

def unfuse_lora(self, components: List[str] = [], **kwargs):
r"""
@@ -661,9 +670,18 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter in set(module.merged_adapters):
if adapter and adapter in self._merged_adapters:
self._merged_adapters = self._merged_adapters - {adapter}
module.unmerge()

self.num_fused_loras -= 1
@property
def num_fused_loras(self):
return len(self._merged_adapters)

@property
def fused_loras(self):
return self._merged_adapters

def set_adapters(
self,
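The gist of the `lora_base.py` change: instead of incrementing and decrementing a `num_fused_loras` counter, the mixin now records the names of merged adapters in a `_merged_adapters` set and exposes `num_fused_loras` and `fused_loras` as read-only properties. Below is a minimal sketch (not part of this PR) of how that bookkeeping surfaces to users, assuming an SDXL pipeline and hypothetical LoRA repositories:

```python
# Sketch only: repository ids and adapter names are placeholders, not from this PR.
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Hypothetical LoRA repositories, used here only to illustrate the tracking behavior.
pipe.load_lora_weights("user/lora-one", adapter_name="adapter-1")
pipe.load_lora_weights("user/lora-two", adapter_name="adapter-2")

pipe.fuse_lora(components=["unet", "text_encoder"], adapter_names=["adapter-1"])
assert pipe.num_fused_loras == 1          # derived from len(pipe.fused_loras)
assert pipe.fused_loras == {"adapter-1"}  # names of the adapters that were merged

pipe.fuse_lora(components=["unet", "text_encoder"], adapter_names=["adapter-2"])
assert pipe.num_fused_loras == 2          # the tracking set now holds both names

pipe.unfuse_lora(components=["unet", "text_encoder"])
assert pipe.num_fused_loras == 0          # unmerging empties the set again
```

Because the state is a set of adapter names rather than a counter, fusing the same adapter twice no longer inflates the count, and unfusing removes exactly the adapters that were merged.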
3 changes: 3 additions & 0 deletions tests/lora/test_lora_layers_cogvideox.py
@@ -124,6 +124,9 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)

@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
34 changes: 34 additions & 0 deletions tests/lora/test_lora_layers_sdxl.py
@@ -117,6 +117,40 @@ def tearDown(self):
def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error()

def test_simple_inference_with_text_denoiser_lora_unfused(self):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
else:
expected_atol = 1e-3
expected_rtol = 1e-3

super().test_simple_inference_with_text_denoiser_lora_unfused(
expected_atol=expected_atol, expected_rtol=expected_rtol
)

def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
else:
expected_atol = 1e-3
expected_rtol = 1e-3

super().test_simple_inference_with_text_lora_denoiser_fused_multi(
expected_atol=expected_atol, expected_rtol=expected_rtol
)

def test_lora_scale_kwargs_match_fusion(self):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
else:
expected_atol = 1e-3
expected_rtol = 1e-3

super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)


@slow
@nightly
120 changes: 92 additions & 28 deletions tests/lora/utils.py
@@ -80,6 +80,18 @@ def initialize_dummy_state_dict(state_dict):
POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def determine_attention_kwargs_name(pipeline_class):
call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()

# TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
attention_kwargs_name = None
for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
assert attention_kwargs_name is not None
return attention_kwargs_name


@require_peft_backend
class PeftLoraLoaderMixinTests:
pipeline_class = None
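For reference, a small sketch of how the new helper can be used outside of this mixin; the pipeline class is only an example and the import path is an assumption based on the file location in this diff:

```python
# Sketch: resolve the attention-kwargs argument name for a given pipeline class.
from diffusers import StableDiffusionXLPipeline

from tests.lora.utils import determine_attention_kwargs_name  # import path assumed

name = determine_attention_kwargs_name(StableDiffusionXLPipeline)
# For SDXL this should resolve to "cross_attention_kwargs", so a LoRA scale can
# be passed per call as:
attention_kwargs = {name: {"scale": 0.8}}
# output = pipe(prompt, **attention_kwargs)  # `pipe` and `prompt` defined elsewhere
```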
@@ -442,14 +454,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected
"""
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

# TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
assert attention_kwargs_name is not None
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -740,12 +745,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
Tests a simple inference with lora attached on the text encoder + Unet + scale argument
and makes sure it works as expected
"""
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
assert attention_kwargs_name is not None
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -878,9 +878,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

# unloading should remove the LoRA layers
@@ -1608,26 +1610,21 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")

# Attach a second adapter
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

denoiser.add_adapter(denoiser_lora_config, "adapter-2")

self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")

if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

# set them to multi-adapter inference mode
pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1637,6 +1634,7 @@
outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

# Fusing should still keep the LoRA layers so output should remain the same
outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1647,16 +1645,87 @@
)

pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
)

pipe.fuse_lora(
components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
)
self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

# Fusing should still keep the LoRA layers
output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

for lora_scale in [1.0, 0.8]:
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly set in text encoder 2",
)

pipe.set_adapters(["adapter-1"])
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

pipe.fuse_lora(
components=self.pipeline_class._lora_loadable_modules,
adapter_names=["adapter-1"],
lora_scale=lora_scale,
)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue(
np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
self.assertFalse(
np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
"LoRA should change the output",
)

@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self):
@@ -1838,12 +1907,7 @@ def test_logs_info_when_no_lora_keys_found(self):

def test_set_adapters_match_attention_kwargs(self):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
assert attention_kwargs_name is not None
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
Expand Down