diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 34a344528e3e..83a7c69e8bc3 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -265,11 +265,11 @@ jobs: - name: Run fast PyTorch LoRA tests with PEFT run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ + python -m pytest -n 6 --max-worker-restart=0 --dist=loadfile \ -s -v \ --make-reports=tests_peft_main \ tests/lora/ - python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ + python -m pytest -n 6 --max-worker-restart=0 --dist=loadfile \ -s -v \ --make-reports=tests_models_lora_peft_main \ tests/models/ -k "lora" diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index d119feae20d0..45a4c4834a16 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -103,34 +103,6 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - @unittest.skip("Not supported in AuraFlow.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in AuraFlow.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - @unittest.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index 565d6db69727..902fbaf631e0 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -120,11 +120,25 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)]) + def test_lora_set_adapters_scenarios(self, scenario): + super()._test_lora_set_adapters_scenarios(scenario, expected_atol=9e-3) + + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): + super()._test_lora_actions(action, components_to_add, expected_atol=9e-3) def test_lora_scale_kwargs_match_fusion(self): 
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3) @@ -136,38 +150,8 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream): # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 super()._test_group_offloading_inference_denoiser(offload_type, use_stream) - @unittest.skip("Not supported in CogVideoX.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in CogVideoX.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_save_load(self): - pass - - @unittest.skip("Not supported in CogVideoX.") - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - pass + # TODO: skip them properly diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index b7367d9b0946..201e0984ff1e 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -13,10 +13,8 @@ # limitations under the License. 
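Note for readers of this diff: the per-pipeline `test_simple_inference_with_*` methods are being collapsed into two parameterized entry points (`test_lora_actions` and `test_lora_set_adapters_scenarios`) that forward to shared `_test_*` helpers in `tests/lora/utils.py`; those helpers themselves are not visible in this excerpt. The snippet below is only a minimal, self-contained sketch of the `parameterized.expand` dispatch pattern being adopted, with a hypothetical class name and a stub helper, not the actual diffusers implementation.

import unittest

from parameterized import parameterized


class ExampleLoraActionTests(unittest.TestCase):
    # Hypothetical stand-in for PeftLoraLoaderMixinTests; only the dispatch
    # mechanics are illustrated here.

    @parameterized.expand(
        [
            ("fused", "text_encoder_only"),
            ("unloaded", "text_and_denoiser"),
            ("save_load", "text_and_denoiser"),
        ]
    )
    def test_lora_actions(self, action, components_to_add):
        # parameterized.expand generates one test method per tuple, e.g.
        # test_lora_actions_0_fused, so selectors like `pytest -k "lora"`
        # still collect each case individually.
        self._test_lora_actions(action, components_to_add, expected_atol=9e-3)

    def _test_lora_actions(self, action, components_to_add, expected_atol=1e-3):
        # Stub for the shared helper that would attach the LoRA adapters and
        # exercise the requested action; here it only validates the parameters.
        self.assertIn(action, {"fused", "unloaded", "unfused", "save_load"})
        self.assertIn(components_to_add, {"text_encoder_only", "text_and_denoiser"})
        self.assertGreater(expected_atol, 0.0)


if __name__ == "__main__":
    unittest.main()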
import sys -import tempfile import unittest -import numpy as np import torch from parameterized import parameterized from transformers import AutoTokenizer, GlmModel @@ -27,7 +25,6 @@ require_peft_backend, require_torch_accelerator, skip_mps, - torch_device, ) @@ -113,40 +110,21 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - def test_simple_inference_save_pretrained(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained - """ - for scheduler_cls in self.scheduler_classes: - components, _, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) - - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): + super()._test_lora_actions(action, components_to_add, expected_atol=9e-3) @parameterized.expand([("block_level", True), ("leaf_level", False)]) @require_torch_accelerator @@ -155,34 +133,6 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream): # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338 super()._test_group_offloading_inference_denoiser(offload_type, use_stream) - @unittest.skip("Not supported in CogView4.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in CogView4.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - @unittest.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def 
test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 336ac2246fd2..4cf40e16c60a 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -263,21 +263,11 @@ def test_lora_expansion_works_for_extra_keys(self): "LoRA should lead to different results.", ) - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - @unittest.skip("Not supported in Flux.") def test_modify_padding_mode(self): pass - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - pass + # TODO: skip them properly class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @@ -791,21 +781,11 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) self.assertTrue(pipe.transformer.config.in_channels == in_features * 2) - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - @unittest.skip("Not supported in Flux.") def test_modify_padding_mode(self): pass - @unittest.skip("Not supported in Flux.") - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - pass + # TODO: skip them properly @slow diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index 19e31f320d0a..07de084d736f 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -19,6 +19,7 @@ import numpy as np import pytest import torch +from parameterized import parameterized from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from diffusers import ( @@ -150,49 +151,33 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - # TODO(aryan): Fix the following test - @unittest.skip("This test fails with an error I haven't been able to debug yet.") - def test_simple_inference_save_pretrained(self): - pass - - @unittest.skip("Not supported in HunyuanVideo.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in HunyuanVideo.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)]) + def test_lora_set_adapters_scenarios(self, scenario): + expected_atol = 9e-3 + if scenario == "weighted": + expected_atol = 1e-3 + super()._test_lora_set_adapters_scenarios(scenario, expected_atol=expected_atol) + + @parameterized.expand( + [ + # Test actions on 
text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): + super()._test_lora_actions(action, components_to_add, expected_atol=9e-3) @unittest.skip("Not supported in HunyuanVideo.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @nightly @require_torch_accelerator diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 88949227cf94..20dc2d5f2c45 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -16,6 +16,7 @@ import unittest import torch +from parameterized import parameterized from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( @@ -108,40 +109,26 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - @unittest.skip("Not supported in LTXVideo.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in LTXVideo.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)]) + def test_lora_set_adapters_scenarios(self, scenario): + super()._test_lora_set_adapters_scenarios(scenario, expected_atol=9e-3) + + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): + super()._test_lora_actions(action, components_to_add, expected_atol=9e-3) @unittest.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def 
test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index d7096e79b93c..63eb0e41b65e 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -101,38 +101,11 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - @unittest.skip("Not supported in Lumina2.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Lumina2.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - @unittest.skip("Not supported in Lumina2.") def test_modify_padding_mode(self): pass @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), @@ -141,11 +114,9 @@ def test_simple_inference_with_text_lora_save_load(self): ) def test_lora_fuse_nan(self): for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = ( + self._setup_pipeline_and_get_base_output(scheduler_cls) + ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 501a4b35f48e..d6dd59b4b182 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -16,14 +16,11 @@ import unittest import torch +from parameterized import parameterized from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel -from diffusers.utils.testing_utils import ( - floats_tensor, - require_peft_backend, - skip_mps, -) +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps sys.path.append(".") @@ -99,44 +96,28 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - 
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - @unittest.skip("Not supported in Mochi.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Mochi.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)]) + def test_lora_set_adapters_scenarios(self, scenario): + super()._test_lora_set_adapters_scenarios(scenario, expected_atol=9e-3) + + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): + super()._test_lora_actions(action, components_to_add, expected_atol=9e-3) @unittest.skip("Not supported in Mochi.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_save_load(self): - pass - - @unittest.skip("Not supported in CogVideoX.") - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - pass + # TODO: skip them properly diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index 24beb46b95ff..5606d06a9945 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -108,31 +108,3 @@ def get_dummy_inputs(self, with_generator=True): @unittest.skip("Not supported in SANA.") def test_modify_padding_mode(self): pass - - @unittest.skip("Not supported in SANA.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in SANA.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 1c5a9b00e9da..32d48e7ed90b 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -20,6 +20,7 @@ import torch import torch.nn as nn from huggingface_hub import 
hf_hub_download +from parameterized import parameterized from safetensors.torch import load_file from transformers import CLIPTextModel, CLIPTokenizer @@ -208,6 +209,19 @@ def test_integration_move_lora_dora_cpu(self): if "lora_" in name: self.assertNotEqual(param.device, torch.device("cpu")) + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)]) + def test_lora_set_adapters_scenarios(self, scenario): + if torch.cuda.is_available(): + expected_atol = 9e-2 + expected_rtol = 9e-2 + else: + expected_atol = 1e-3 + expected_rtol = 1e-3 + + super()._test_lora_set_adapters_scenarios( + scenario=scenario, expected_atol=expected_atol, expected_rtol=expected_rtol + ) + @slow @require_torch_accelerator def test_integration_set_lora_device_different_target_layers(self): diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 8a8f2a676df1..344fa0aedfa5 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -114,17 +114,7 @@ def test_sd3_lora(self): lora_filename = "lora_peft_format.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - @unittest.skip("Not supported in SD3.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in SD3.") - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - pass - - @unittest.skip("Not supported in SD3.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass + # TODO: skip them properly @unittest.skip("Not supported in SD3.") def test_modify_padding_mode(self): diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 267650056aad..eee1502b5849 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -22,6 +22,7 @@ import numpy as np import torch from packaging import version +from parameterized import parameterized from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( @@ -117,7 +118,20 @@ def tearDown(self): def test_multiple_wrong_adapter_name_raises_error(self): super().test_multiple_wrong_adapter_name_raises_error() - def test_simple_inference_with_text_denoiser_lora_unfused(self): + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -125,11 +139,10 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): expected_atol = 1e-3 expected_rtol = 1e-3 - super().test_simple_inference_with_text_denoiser_lora_unfused( - expected_atol=expected_atol, expected_rtol=expected_rtol - ) + super()._test_lora_actions(action, components_to_add, expected_atol=expected_atol, expected_rtol=expected_rtol) - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",), ("fused_multi",)]) + def test_lora_set_adapters_scenarios(self, scenario): if torch.cuda.is_available(): expected_atol = 9e-2 expected_rtol = 9e-2 @@ -137,8 +150,8 @@ def 
test_simple_inference_with_text_lora_denoiser_fused_multi(self): expected_atol = 1e-3 expected_rtol = 1e-3 - super().test_simple_inference_with_text_lora_denoiser_fused_multi( - expected_atol=expected_atol, expected_rtol=expected_rtol + super()._test_lora_set_adapters_scenarios( + scenario=scenario, expected_atol=expected_atol, expected_rtol=expected_rtol ) def test_lora_scale_kwargs_match_fusion(self): diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index fe26a56e77cf..9933ad7ebd4b 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -16,19 +16,11 @@ import unittest import torch +from parameterized import parameterized from transformers import AutoTokenizer, T5EncoderModel -from diffusers import ( - AutoencoderKLWan, - FlowMatchEulerDiscreteScheduler, - WanPipeline, - WanTransformer3DModel, -) -from diffusers.utils.testing_utils import ( - floats_tensor, - require_peft_backend, - skip_mps, -) +from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps sys.path.append(".") @@ -104,40 +96,26 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - @unittest.skip("Not supported in Wan.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Wan.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)]) + def test_lora_set_adapters_scenarios(self, scenario): + super()._test_lora_set_adapters_scenarios(scenario, expected_atol=9e-3) + + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): + super()._test_lora_actions(action, components_to_add, expected_atol=9e-3) @unittest.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index a7eb74080499..dfa04505743c 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ 
b/tests/lora/test_lora_layers_wanvace.py @@ -21,19 +21,13 @@ import pytest import safetensors.torch import torch +from parameterized import parameterized from PIL import Image from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel from diffusers.utils.import_utils import is_peft_available -from diffusers.utils.testing_utils import ( - floats_tensor, - is_flaky, - require_peft_backend, - require_peft_version_greater, - skip_mps, - torch_device, -) +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device if is_peft_available(): @@ -121,44 +115,30 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - @unittest.skip("Not supported in Wan VACE.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in Wan VACE.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)]) + def test_lora_set_adapters_scenarios(self, scenario): + super()._test_lora_set_adapters_scenarios(scenario, expected_atol=9e-3) + + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, components_to_add): + super()._test_lora_actions(action, components_to_add, expected_atol=9e-3) @unittest.skip("Not supported in Wan VACE.") def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @pytest.mark.xfail( condition=True, reason="RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same", @@ -167,7 +147,6 @@ def test_simple_inference_with_text_lora_save_load(self): def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() - @require_peft_version_greater("0.13.2") def test_lora_exclude_modules_wanvace(self): scheduler_cls = self.scheduler_classes[0] exclude_module_name = "vace_blocks.0.proj_out" @@ -216,7 +195,3 @@ def test_lora_exclude_modules_wanvace(self): np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should 
match.", ) - - @is_flaky - def test_simple_inference_with_text_denoiser_lora_and_scale(self): - super().test_simple_inference_with_text_denoiser_lora_and_scale() diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 91ca188137e7..0633c8574f8f 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -17,7 +17,6 @@ import os import re import tempfile -import unittest from itertools import product import numpy as np @@ -39,7 +38,6 @@ floats_tensor, is_torch_version, require_peft_backend, - require_peft_version_greater, require_torch_accelerator, require_transformers_version_greater, skip_mps, @@ -129,1073 +127,197 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): - if self.unet_kwargs and self.transformer_kwargs: - raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") - if self.has_two_text_encoders and self.has_three_text_encoders: - raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") - - scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls - rank = 4 - lora_alpha = rank if lora_alpha is None else lora_alpha - - torch.manual_seed(0) - if self.unet_kwargs is not None: - unet = UNet2DConditionModel(**self.unet_kwargs) - else: - transformer = self.transformer_cls(**self.transformer_kwargs) - - scheduler = scheduler_cls(**self.scheduler_kwargs) - - torch.manual_seed(0) - vae = self.vae_cls(**self.vae_kwargs) - - text_encoder = self.text_encoder_cls.from_pretrained( - self.text_encoder_id, subfolder=self.text_encoder_subfolder - ) - tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder) - - if self.text_encoder_2_cls is not None: - text_encoder_2 = self.text_encoder_2_cls.from_pretrained( - self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder - ) - tokenizer_2 = self.tokenizer_2_cls.from_pretrained( - self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder - ) - - if self.text_encoder_3_cls is not None: - text_encoder_3 = self.text_encoder_3_cls.from_pretrained( - self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder - ) - tokenizer_3 = self.tokenizer_3_cls.from_pretrained( - self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder - ) - - text_lora_config = LoraConfig( - r=rank, - lora_alpha=lora_alpha, - target_modules=self.text_encoder_target_modules, - init_lora_weights=False, - use_dora=use_dora, - ) - - denoiser_lora_config = LoraConfig( - r=rank, - lora_alpha=lora_alpha, - target_modules=self.denoiser_target_modules, - init_lora_weights=False, - use_dora=use_dora, - ) - - pipeline_components = { - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - } - # Denoiser - if self.unet_kwargs is not None: - pipeline_components.update({"unet": unet}) - elif self.transformer_kwargs is not None: - pipeline_components.update({"transformer": transformer}) - - # Remaining text encoders. 
- if self.text_encoder_2_cls is not None: - pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2}) - if self.text_encoder_3_cls is not None: - pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3}) - - # Remaining stuff - init_params = inspect.signature(self.pipeline_class.__init__).parameters - if "safety_checker" in init_params: - pipeline_components.update({"safety_checker": None}) - if "feature_extractor" in init_params: - pipeline_components.update({"feature_extractor": None}) - if "image_encoder" in init_params: - pipeline_components.update({"image_encoder": None}) - - return pipeline_components, text_lora_config, denoiser_lora_config - - @property - def output_shape(self): - raise NotImplementedError - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 10 - num_channels = 4 - sizes = (32, 32) - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_channels) + sizes) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - pipeline_inputs = { - "prompt": "A painting of a squirrel eating a burger", - "num_inference_steps": 5, - "guidance_scale": 6.0, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb - def get_dummy_tokens(self): - max_seq_length = 77 - - inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) - - prepared_inputs = {} - prepared_inputs["input_ids"] = inputs - return prepared_inputs - - def _get_lora_state_dicts(self, modules_to_save): - state_dicts = {} - for module_name, module in modules_to_save.items(): - if module is not None: - state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) - return state_dicts - - def _get_lora_adapter_metadata(self, modules_to_save): - metadatas = {} - for module_name, module in modules_to_save.items(): - if module is not None: - metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict() - return metadatas - - def _get_modules_to_save(self, pipe, has_denoiser=False): - modules_to_save = {} - lora_loadable_modules = self.pipeline_class._lora_loadable_modules - - if ( - "text_encoder" in lora_loadable_modules - and hasattr(pipe, "text_encoder") - and getattr(pipe.text_encoder, "peft_config", None) is not None - ): - modules_to_save["text_encoder"] = pipe.text_encoder - - if ( - "text_encoder_2" in lora_loadable_modules - and hasattr(pipe, "text_encoder_2") - and getattr(pipe.text_encoder_2, "peft_config", None) is not None - ): - modules_to_save["text_encoder_2"] = pipe.text_encoder_2 - - if has_denoiser: - if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): - modules_to_save["unet"] = pipe.unet - - if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): - modules_to_save["transformer"] = pipe.transformer - - return modules_to_save - - def _get_exclude_modules(self, pipe): - from diffusers.utils.peft_utils import _derive_exclude_modules - - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - denoiser = "unet" if self.unet_kwargs is not None else "transformer" - modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser} - denoiser_lora_state_dict = 
self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"] - pipe.unload_lora_weights() - denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict() - exclude_modules = _derive_exclude_modules( - denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default" - ) - return exclude_modules - - def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): - if text_lora_config is not None: - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - if denoiser_lora_config is not None: - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - else: - denoiser = None - - if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - return pipe, denoiser - - def test_simple_inference(self): - """ - Tests a simple inference and makes sure it works as expected - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs() - output_no_lora = pipe(**inputs)[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - def test_simple_inference_with_text_lora(self): - """ - Tests a simple inference with lora attached on the text encoder - and makes sure it works as expected - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - @require_peft_version_greater("0.13.1") def test_low_cpu_mem_usage_with_injection(self): """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output( + scheduler_cls + ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) 
self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." ) self.assertTrue( - "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, - "The LoRA params should be on 'meta' device.", - ) - - te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) - set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, - "No param should be on 'meta' device.", - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - self.assertTrue( - "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." - ) - - denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) - set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." - ) - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - self.assertTrue( - "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "The LoRA params should be on 'meta' device.", - ) - - te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) - set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) - self.assertTrue( - "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, - "No param should be on 'meta' device.", - ) - - _, _, inputs = self.get_dummy_inputs() - output_lora = pipe(**inputs)[0] - self.assertTrue(output_lora.shape == self.output_shape) - - @require_peft_version_greater("0.13.1") - @require_transformers_version_greater("4.45.2") - def test_low_cpu_mem_usage_with_loading(self): - """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" - - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) 
- - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - # Now, check for `low_cpu_mem_usage.` - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) - - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - - images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose( - images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3 - ), - "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", - ) - - def test_simple_inference_with_text_lora_and_scale(self): - """ - Tests a simple inference with lora attached on the text encoder + scale argument - and makes sure it works as expected - """ - attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} - output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - - attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} - output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", - ) - - def test_simple_inference_with_text_lora_fused(self): - """ - Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model - and makes sure it works as expected - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - pipe.fuse_lora() - # Fusing should still keep the LoRA layers - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set 
in text encoder") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) - - def test_simple_inference_with_text_lora_unloaded(self): - """ - Tests a simple inference with lora attached to text encoder, then unloads the lora weights - and makes sure it works as expected - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - pipe.unload_lora_weights() - # unloading should remove the LoRA layers - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" - ) - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) - - ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", - ) - - def test_simple_inference_with_text_lora_save_load(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA. 
- """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_partial_text_lora(self): - """ - Tests a simple inference with lora attached on the text encoder - with different ranks and some adapters removed - and makes sure it works as expected - """ - for scheduler_cls in self.scheduler_classes: - components, _, _ = self.get_dummy_components(scheduler_cls) - # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). - text_lora_config = LoraConfig( - r=4, - rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, - lora_alpha=4, - target_modules=self.text_encoder_target_modules, - init_lora_weights=False, - use_dora=False, - ) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - - state_dict = {} - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` - # supports missing layers (PR#8324). 
- state_dict = { - f"text_encoder.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() - if "text_model.encoder.layers.4" not in module_name - } - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - state_dict.update( - { - f"text_encoder_2.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() - if "text_model.encoder.layers.4" not in module_name - } - ) - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - # Unload lora and load it back using the pipe.load_lora_weights machinery - pipe.unload_lora_weights() - pipe.load_lora_weights(state_dict) - - output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), - "Removing adapters should change the output", - ) - - def test_simple_inference_save_pretrained_with_text_lora(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), - "Lora not correctly set in text encoder", - ) - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), - "Lora not correctly set in text encoder 2", - ) - - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_text_denoiser_lora_save_load(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - 
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - pipe.unload_lora_weights() - pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - - for module_name, module in modules_to_save.items(): - self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - - images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - def test_simple_inference_with_text_denoiser_lora_and_scale(self): - """ - Tests a simple inference with lora attached on the text encoder + Unet + scale argument - and makes sure it works as expected - """ - attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) - - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" - ) - - attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} - output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), - "Lora + scale should change the output", - ) - - attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} - output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] - - self.assertTrue( - np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), - "Lora + 0 scale should lead to same result as no LoRA", - ) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, - "The scaling parameter has not been correctly restored!", - ) - - def test_simple_inference_with_text_lora_denoiser_fused(self): - """ - Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model - and makes sure it works as expected - with unet - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - 
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) - - # Fusing should still keep the LoRA layers - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" - ) - - def test_simple_inference_with_text_denoiser_lora_unloaded(self): - """ - Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights - and makes sure it works as expected - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - pipe.unload_lora_weights() - # unloading should remove the LoRA layers - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" - ) - self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) - - output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), - "Fused lora should change the output", - ) - - def test_simple_inference_with_text_denoiser_lora_unfused( - self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 - ): - """ - Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights - and makes sure it works as expected - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) - - pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - self.assertTrue(pipe.num_fused_loras == 0, 
f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - # unloading should remove the LoRA layers - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") - - self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) - - # Fuse and unfuse should lead to the same results - self.assertTrue( - np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol), - "Fused lora should not change the output", - ) - - def test_simple_inference_with_text_denoiser_multi_adapter(self): - """ - Tests a simple inference with lora attached to text encoder and unet, attaches - multiple adapters and set them - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) - - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - - pipe.set_adapters("adapter-1") - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) - - pipe.set_adapters("adapter-2") - output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) - - pipe.set_adapters(["adapter-1", "adapter-2"]) - output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertFalse( - np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter outputs should be different.", - ) - - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed 
adapters should give different results", - ) - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) - - pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) - - def test_wrong_adapter_name_raises_error(self): - adapter_name = "adapter-1" - - scheduler_cls = self.scheduler_classes[0] - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name - ) - - with self.assertRaises(ValueError) as err_context: - pipe.set_adapters("test") - - self.assertTrue("not in the list of present adapters" in str(err_context.exception)) - - # test this works. - pipe.set_adapters(adapter_name) - _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - def test_multiple_wrong_adapter_name_raises_error(self): - adapter_name = "adapter-1" - scheduler_cls = self.scheduler_classes[0] - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - pipe, _ = self.add_adapters_to_pipeline( - pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name - ) - - scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} - logger = logging.get_logger("diffusers.loaders.lora_base") - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: - pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) - - wrong_components = sorted(set(scale_with_wrong_components.keys())) - msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " - self.assertTrue(msg in str(cap_logger.out)) - - # test this works. - pipe.set_adapters(adapter_name) - _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - - def test_simple_inference_with_text_denoiser_block_scale(self): - """ - Tests a simple inference with lora attached to text encoder and unet, attaches - one adapter and set different weights for different blocks (i.e. 
block lora) - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, + "The LoRA params should be on 'meta' device.", + ) - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) + set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, + "No param should be on 'meta' device.", + ) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config) + inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + self.assertTrue( + "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." + ) + + denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) + set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." + ) if self.has_two_text_encoders or self.has_three_text_encoders: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) self.assertTrue( check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" ) + self.assertTrue( + "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "The LoRA params should be on 'meta' device.", + ) - weights_1 = {"text_encoder": 2, "unet": {"down": 5}} - pipe.set_adapters("adapter-1", weights_1) - output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - weights_2 = {"unet": {"up": 5}} - pipe.set_adapters("adapter-1", weights_2) - output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) + set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "No param should be on 'meta' device.", + ) - self.assertFalse( - np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), - "LoRA weights 1 and 2 should give different results", - ) - self.assertFalse( - np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 1 should give different results", - ) - self.assertFalse( - np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3), - "No adapter and LoRA weights 2 should give different results", - ) + _, _, inputs = self.get_dummy_inputs() + output_lora = pipe(**inputs)[0] + self.assertTrue(output_lora.shape == self.output_shape) - pipe.disable_lora() 
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + @require_transformers_version_greater("4.45.2") + def test_low_cpu_mem_usage_with_loading(self): + """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + for scheduler_cls in self.scheduler_classes: + pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output( + scheduler_cls ) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): - """ - Tests a simple inference with lora attached to text encoder and unet, attaches - multiple adapters and set different weights for different blocks (i.e. block lora) - """ - for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdirname, low_cpu_mem_usage=False) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", ) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + # Now, check for `low_cpu_mem_usage.` + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdirname, low_cpu_mem_usage=True) - if self.has_two_text_encoders or self.has_three_text_encoders: - if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - scales_1 = {"text_encoder": 2, "unet": {"down": 5}} - 
scales_2 = {"unet": {"down": 5, "mid": 5}} + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose( + images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3 + ), + "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", + ) - pipe.set_adapters("adapter-1", scales_1) - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + def test_simple_inference_save_pretrained_with_text_lora(self): + """ + Tests a simple usecase where users could use saving utilities for text encoder (only) + LoRA through save_pretrained. + """ + if not any("text_encoder" in k for k in self.pipeline_class._lora_loadable_modules): + pytest.skip("Test not supported.") + for scheduler_cls in self.scheduler_classes: + pipe, inputs, _, text_lora_config, _ = self._setup_pipeline_and_get_base_output(scheduler_cls) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.set_adapters("adapter-2", scales_2) - output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=False) - pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2]) - output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") - # Fuse and unfuse should lead to the same results - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", ) - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", + def test_simple_inference_with_partial_text_lora(self): + """ + Tests a simple inference with lora attached on the text encoder + with different ranks and some adapters removed + and makes sure it works as expected + """ + if not any("text_encoder" in k for k in self.pipeline_class._lora_loadable_modules): + pytest.skip("Test not supported.") + for scheduler_cls in self.scheduler_classes: + pipe, inputs, output_no_lora, _, _ = self._setup_pipeline_and_get_base_output(scheduler_cls) + # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). 
+ text_lora_config = LoraConfig( + r=4, + rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, + lora_alpha=4, + target_modules=self.text_encoder_target_modules, + init_lora_weights=False, + use_dora=False, ) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) + state_dict = {} + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` + # supports missing layers (PR#8324). + state_dict = { + f"text_encoder.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + if "text_model.encoder.layers.4" not in module_name + } - pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + state_dict.update( + { + f"text_encoder_2.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + if "text_model.encoder.layers.4" not in module_name + } + ) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", + not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) - # a mismatching number of adapter_names and adapter_weights should raise an error - with self.assertRaises(ValueError): - pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) + # Unload lora and load it back using the pipe.load_lora_weights machinery + pipe.unload_lora_weights() + pipe.load_lora_weights(state_dict) + + output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), + "Removing adapters should change the output", + ) def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" + if self.unet_kwargs is None: + pytest.skip("Test not supported.") def updown_options(blocks_with_tf, layers_per_block, value): """ @@ -1260,14 +382,10 @@ def all_possible_dict_opts(unet, value): return opts - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + pipe, _, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output( + self.scheduler_classes[0] + ) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet denoiser.add_adapter(denoiser_lora_config, "adapter-1") @@ -1283,135 +401,212 @@ def all_possible_dict_opts(unet, value): pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error - def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): + def test_simple_inference_with_dora(self): + for scheduler_cls in 
self.scheduler_classes: + pipe, inputs, output_no_dora_lora, text_lora_config, denoiser_lora_config = ( + self._setup_pipeline_and_get_base_output(scheduler_cls, use_dora=True) + ) + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse( + np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3), + "DoRA lora should change the output", + ) + + def _test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3, expected_rtol=1e-3): """ - Tests a simple inference with lora attached to text encoder and unet, attaches - multiple adapters and set/delete them + A unified test for various LoRA actions (fusing, unloading, saving/loading, etc.) + on different combinations of model components. """ + # Skip text_encoder tests if the pipeline doesn't support it + if lora_components_to_add == "text_encoder_only" and not any( + "text_encoder" in k for k in self.pipeline_class._lora_loadable_modules + ): + pytest.skip(f"Test not supported for {self.__class__.__name__} without a LoRA-compatible text encoder.") + for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) + # 1. Setup pipeline and get base output + pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = ( + self._setup_pipeline_and_get_base_output(scheduler_cls) + ) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + # 2. Add LoRA adapters based on the parameterization + if lora_components_to_add == "text_encoder_only": + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) + elif lora_components_to_add == "text_and_denoiser": + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) + else: + raise ValueError(f"Unknown `lora_components_to_add`: {lora_components_to_add}") + modules_to_save = self._get_modules_to_save( + pipe, has_denoiser=lora_components_to_add != "text_encoder_only" + ) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(not np.allclose(output_lora, output_no_lora, atol=expected_atol, rtol=expected_rtol)) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + # 3. 
Perform the specified action and assert the outcome + if action == "fused": + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + not np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=expected_rtol), + "Fused LoRA should produce a different output from the base model.", ) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") - - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + elif action == "unloaded": + pipe.unload_lora_weights() + for module_name, module in modules_to_save.items(): self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + not check_if_lora_correctly_set(module), f"Lora not correctly unloaded in {module_name}" ) + output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(output_after_action, output_no_lora, atol=expected_atol, rtol=expected_rtol), + "Output after unloading LoRA should match the original output.", + ) - pipe.set_adapters("adapter-1") - output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters("adapter-2") - output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] - - pipe.set_adapters(["adapter-1", "adapter-2"]) - output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) - - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) - - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) - - pipe.delete_adapters("adapter-1") - output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + elif action == "unfused": + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue( - np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", - ) + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) + self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + for module_name, module in modules_to_save.items(): +
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + output_unfused = pipe(**inputs, generator=torch.manual_seed(0))[0] - pipe.delete_adapters("adapter-2") - output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(output_fused, output_unfused, atol=expected_atol, rtol=expected_rtol), + "Output after unfusing should match the fused output.", + ) - self.assertTrue( - np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) + elif action == "save_load": + with tempfile.TemporaryDirectory() as tmpdirname: + has_denoiser = lora_components_to_add == "text_and_denoiser" + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=has_denoiser) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdirname) + for module_name, module in modules_to_save.items(): + self.assertTrue( + check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}" + ) - pipe.set_adapters(["adapter-1", "adapter-2"]) - pipe.delete_adapters(["adapter-1", "adapter-2"]) + output_after_action = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(output_lora, output_after_action, atol=expected_atol, rtol=expected_rtol), + "Loading from a saved checkpoint should yield the same result.", + ) - output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0] + elif action == "disable": + pipe.disable_lora() + output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) - self.assertTrue( - np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) + @parameterized.expand( + [ + # Test actions on text_encoder LoRA only + ("fused", "text_encoder_only"), + ("unloaded", "text_encoder_only"), + ("save_load", "text_encoder_only"), + # Test actions on both text_encoder and denoiser LoRA + ("fused", "text_and_denoiser"), + ("unloaded", "text_and_denoiser"), + ("unfused", "text_and_denoiser"), + ("save_load", "text_and_denoiser"), + ("disable", "text_and_denoiser"), + ] + ) + def test_lora_actions(self, action, lora_components_to_add, expected_atol=1e-3): + """Tests to check if different LoRA actions like fusion, loading-unloading, etc. + work as expected. + """ + for cls in inspect.getmro(self.__class__): + if "test_lora_actions" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: + # Skip this test if it is overwritten by child class. 
We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. + return + self._test_lora_actions(action, lora_components_to_add, expected_atol) - def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): + @parameterized.expand([("text_encoder_only",), ("text_and_denoiser",)]) + def test_lora_scaling(self, lora_components_to_add): """ - Tests a simple inference with lora attached to text encoder and unet, attaches - multiple adapters and set them + Tests inference with LoRA scaling applied via attention_kwargs + for different LoRA configurations. """ + if lora_components_to_add == "text_encoder_only": + if not any("text_encoder" in k for k in self.pipeline_class._lora_loadable_modules): + pytest.skip( + f"Test not supported for {self.__class__.__name__} since there is no text encoder in the LoRA loadable modules." + ) + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) + for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) + pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = ( + self._setup_pipeline_and_get_base_output(scheduler_cls) + ) - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + # Create a deep copy to ensure a clean state for each iteration + lora_pipe = copy.deepcopy(pipe) - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) + # Add LoRA components based on the parameterization + if lora_components_to_add == "text_encoder_only": + lora_pipe, _ = self.add_adapters_to_pipeline(lora_pipe, text_lora_config, denoiser_lora_config=None) + elif lora_components_to_add == "text_and_denoiser": + lora_pipe, _ = self.add_adapters_to_pipeline(lora_pipe, text_lora_config, denoiser_lora_config) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - denoiser.add_adapter(denoiser_lora_config, "adapter-2") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + # 1. Test base LoRA output + output_lora = lora_pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "LoRA should change the output." + ) - if self.has_two_text_encoders or self.has_three_text_encoders: - lora_loadable_components = self.pipeline_class._lora_loadable_modules - if "text_encoder_2" in lora_loadable_components: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + # 2.
Test with a scale of 0.5 + attention_kwargs = {attention_kwargs_name: {"scale": 0.5}} + output_lora_scale = lora_pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + self.assertFalse( + np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), + "Using a LoRA scale should change the output.", + ) + + # 3. Test with a scale of 0.0, which should be identical to no LoRA + attention_kwargs = {attention_kwargs_name: {"scale": 0.0}} + output_lora_0_scale = lora_pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0] + self.assertTrue( + np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), + "Using a LoRA scale of 0.0 should be the same as no LoRA.", + ) + + # 4. Final check to ensure the scaling parameter is restored + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + # Get the underlying LoRA layer to check its scaling factor + lora_layer = lora_pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj + if hasattr(lora_layer, "scaling"): + self.assertEqual( + lora_layer.scaling.get("default", 1.0), + 1.0, + "The scaling parameter was not correctly restored!", ) + def _test_lora_set_adapters_scenarios(self, scenario, expected_atol=1e-3, expected_rtol=1e-3): + for scheduler_cls in self.scheduler_classes: + pipe, inputs, output_no_lora, _ = self._setup_multi_adapter_pipeline(scheduler_cls) + + # Run inference with each adapter individually and mixed pipe.set_adapters("adapter-1") output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -1421,37 +616,132 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): pipe.set_adapters(["adapter-1", "adapter-2"]) output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0] - # Fuse and unfuse should lead to the same results + # --- Assert base multi-adapter behavior --- + self.assertFalse(np.allclose(output_no_lora, output_adapter_1, atol=expected_atol, rtol=expected_rtol)) + self.assertFalse(np.allclose(output_adapter_1, output_adapter_2, atol=expected_atol, rtol=expected_rtol)) self.assertFalse( - np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), - "Adapter 1 and 2 should give different results", + np.allclose(output_adapter_1, output_adapter_mixed, atol=expected_atol, rtol=expected_rtol) ) - self.assertFalse( - np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 1 and mixed adapters should give different results", - ) + # --- Scenario-specific logic --- + if scenario == "weighted": + pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) + output_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_weighted, output_adapter_mixed, atol=expected_atol, rtol=expected_rtol) + ) - self.assertFalse( - np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Adapter 2 and mixed adapters should give different results", - ) + elif scenario == "block_lora": + scales = {"unet": {"down": 0.8, "mid": 0.5, "up": 0.2}} + pipe.set_adapters("adapter-1", scales) + output_block_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_block_lora_1, output_adapter_mixed, atol=expected_atol, rtol=expected_rtol) + ) - pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6]) - output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0] + text_encoder_modules = [k for k in self.pipeline_class._lora_loadable_modules if "text_encoder" in k] + if 
text_encoder_modules: + text_encoder_module_name = text_encoder_modules[0] + scales_2 = {text_encoder_module_name: 2, "unet": {"down": 5}} + pipe.set_adapters("adapter-1", scales_2) + output_block_lora_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertFalse( + np.allclose(output_block_lora_2, output_adapter_mixed, atol=expected_atol, rtol=expected_rtol) + ) - self.assertFalse( - np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3), - "Weighted adapter and mixed adapter should give different results", - ) + elif scenario == "delete_adapter": + pipe.set_adapters("adapter-2") + output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] + pipe.delete_adapters("adapter-1") + output_after_delete = pipe(**inputs, generator=torch.manual_seed(0))[0] + # After deleting adapter-1, the output should be the same as using only adapter-2 + self.assertTrue( + np.allclose(output_after_delete, output_adapter_2, atol=expected_atol, rtol=expected_rtol) + ) + + elif scenario == "fused_multi": + # 1. Fuse a single adapter + pipe.set_adapters("adapter-1") + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"]) + output_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(output_adapter_1, output_lora_1_fused, atol=expected_atol, rtol=expected_rtol) + ) + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) - pipe.disable_lora() - output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0] + # 2. Fuse both adapters + pipe.set_adapters(["adapter-1", "adapter-2"]) + pipe.fuse_lora( + components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1", "adapter-2"] + ) + self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") - self.assertTrue( - np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), - "output with no lora and output with lora disabled should give same results", - ) + output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(output_all_lora_fused, output_adapter_mixed, atol=expected_atol, rtol=expected_rtol) + ) + pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules) + self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}") + + @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",), ("fused_multi",)]) + def test_lora_set_adapters_scenarios(self, scenario, expected_atol=1e-3, expected_rtol=1e-3): + """ + A unified test for various multi-adapter (and single-adapter) LoRA scenarios, including weighting, + block scaling, and adapter deletion. + """ + for cls in inspect.getmro(self.__class__): + if "test_lora_set_adapters_scenarios" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests: + # Skip this test if it is overwritten by child class. We need to do this because parameterized + # materializes the test methods on invocation which cannot be overridden. 
+ return + + if scenario == "block_lora" and self.unet_kwargs is None: + pytest.skip(f"Test not supported for {scenario=}.") + + self._test_lora_set_adapters_scenarios(scenario, expected_atol, expected_rtol) + + def test_wrong_adapter_name_raises_error(self): + adapter_name = "adapter-1" + scheduler_cls = self.scheduler_classes[0] + pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output( + scheduler_cls + ) + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name + ) + + with self.assertRaises(ValueError) as err_context: + pipe.set_adapters("test") + + self.assertTrue("not in the list of present adapters" in str(err_context.exception)) + + # test this works. + pipe.set_adapters(adapter_name) + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + + def test_multiple_wrong_adapter_name_raises_error(self): + adapter_name = "adapter-1" + scheduler_cls = self.scheduler_classes[0] + pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output( + scheduler_cls + ) + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name + ) + + scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} + logger = logging.get_logger("diffusers.loaders.lora_base") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) + + wrong_components = sorted(set(scale_with_wrong_components.keys())) + msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " + self.assertTrue(msg in str(cap_logger.out)) + + # test this works. + pipe.set_adapters(adapter_name) + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] @skip_mps @pytest.mark.xfail( @@ -1461,21 +751,11 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): ) def test_lora_fuse_nan(self): for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - if "text_encoder" in self.pipeline_class._lora_loadable_modules: - pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" - ) + pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output( + scheduler_cls + ) - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet - denoiser.add_adapter(denoiser_lora_config, "adapter-1") - self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config, "adapter-1") # corrupt one LoRA weight with `inf` values with torch.no_grad(): @@ -1523,12 +803,9 @@ def test_get_adapters(self): are the expected results """ for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - + pipe, _, _, text_lora_config, denoiser_lora_config = 
self._setup_pipeline_and_get_base_output( + scheduler_cls + ) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet @@ -1552,11 +829,9 @@ def test_get_list_adapters(self): are the expected results """ for scheduler_cls in self.scheduler_classes: - components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - + pipe, _, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output( + scheduler_cls + ) # 1. dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1569,7 +844,6 @@ def test_get_list_adapters(self): else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") dicts_to_be_checked.update({"transformer": ["adapter-1"]}) - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) # 2. @@ -1584,12 +858,11 @@ def test_get_list_adapters(self): else: pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2") dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) - # 3. pipe.set_adapters(["adapter-1", "adapter-2"]) + # 3. dicts_to_be_checked = {} if "text_encoder" in self.pipeline_class._lora_loadable_modules: dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]} @@ -1598,11 +871,7 @@ def test_get_list_adapters(self): dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]}) else: dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]}) - - self.assertDictEqual( - pipe.get_list_adapters(), - dicts_to_be_checked, - ) + self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) # 4. 
        dicts_to_be_checked = {}
@@ -1618,123 +887,15 @@ def test_get_list_adapters(self):

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

-    @require_peft_version_greater(peft_version="0.6.2")
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(
-        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
-    ):
-        """
-        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
-        and makes sure it works as expected - with unet and multi-adapter case
-        """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_lora.shape == self.output_shape)
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
-                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                lora_loadable_components = self.pipeline_class._lora_loadable_modules
-                if "text_encoder_2" in lora_loadable_components:
-                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                    )
-                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
-
-            # set them to multi-adapter inference mode
-            pipe.set_adapters(["adapter-1", "adapter-2"])
-            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            pipe.set_adapters(["adapter-1"])
-            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
-            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
-            # Fusing should still keep the LoRA layers so output should remain the same
-            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertTrue(
-                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
-                "Fused lora should not change the output",
-            )
-
-            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
-            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
-
-            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
-                    )
-
-            pipe.fuse_lora(
-                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
-            )
-            self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
-            # Fusing should still keep the LoRA layers
-            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
-                "Fused lora should not change the output",
-            )
-            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
-            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
-    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
+    def test_lora_scale_kwargs_match_fusion(self, expected_atol=1e-3, expected_rtol=1e-3):
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for lora_scale in [1.0, 0.8]:
            for scheduler_cls in self.scheduler_classes:
-                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-                pipe = self.pipeline_class(**components)
-                pipe = pipe.to(torch_device)
-                pipe.set_progress_bar_config(disable=None)
-                _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-                output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-                self.assertTrue(output_no_lora.shape == self.output_shape)
-
-                if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                    )
-
-                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-                denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-                self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-                if self.has_two_text_encoders or self.has_three_text_encoders:
-                    lora_loadable_components = self.pipeline_class._lora_loadable_modules
-                    if "text_encoder_2" in lora_loadable_components:
-                        pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                        self.assertTrue(
-                            check_if_lora_correctly_set(pipe.text_encoder_2),
-                            "Lora not correctly set in text encoder 2",
-                        )
+                pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = (
+                    self._setup_pipeline_and_get_base_output(scheduler_cls)
+                )
+                pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config, "adapter-1")

                pipe.set_adapters(["adapter-1"])
                attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -1758,40 +919,11 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
                    "LoRA should change the output",
                )

-    @require_peft_version_greater(peft_version="0.9.0")
-    def test_simple_inference_with_dora(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
-                scheduler_cls, use_dora=True
-            )
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_dora_lora.shape == self.output_shape)
-
-            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
-            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertFalse(
-                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
-                "DoRA lora should change the output",
-            )
-
    def test_missing_keys_warning(self):
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+        pipe, _, _, _, denoiser_lora_config = self._setup_pipeline_and_get_base_output(scheduler_cls)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, None, denoiser_lora_config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
@@ -1821,14 +953,8 @@ def test_missing_keys_warning(self):
    def test_unexpected_keys_warning(self):
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+        pipe, _, _, _, denoiser_lora_config = self._setup_pipeline_and_get_base_output(scheduler_cls)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, None, denoiser_lora_config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
@@ -1850,30 +976,6 @@ def test_unexpected_keys_warning(self):

        self.assertTrue(".diffusers_cat" in cap_logger.out)

-    @unittest.skip("This is failing for now - need to investigate")
-    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
-        """
-        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
-        and makes sure it works as expected
-        """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
-            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
-
-            # Just makes sure it works..
-            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
    def test_modify_padding_mode(self):
        def set_pad_mode(network, mode="circular"):
            for _, module in network.named_modules():
@@ -1895,13 +997,8 @@ def set_pad_mode(network, mode="circular"):

    def test_logs_info_when_no_lora_keys_found(self):
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
-        components, _, _ = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        pipe, inputs, output_no_lora, _, _ = self._setup_pipeline_and_get_base_output(scheduler_cls)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
        logger = logging.get_logger("diffusers.loaders.peft")
@@ -1913,7 +1010,7 @@ def test_logs_info_when_no_lora_keys_found(self):

        denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
        self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
-        self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+        self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))

        # test only for text encoder
        for lora_module in self.pipeline_class._lora_loadable_modules:
@@ -1941,15 +1038,9 @@ def test_set_adapters_match_attention_kwargs(self):
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(output_no_lora.shape == self.output_shape)
-
+            pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = (
+                self._setup_pipeline_and_get_base_output(scheduler_cls)
+            )
            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

            lora_scale = 0.5
@@ -1977,13 +1068,11 @@ def test_set_adapters_match_attention_kwargs(self):
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
                )
-
-                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
-                pipe = self.pipeline_class(**components)
-                pipe = pipe.to(torch_device)
-                pipe.set_progress_bar_config(disable=None)
-                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-
+                # We should not have to set up the pipeline again. `unload_lora_weights()` should work.
+                # TODO: investigate later.
+                # pipe.unload_lora_weights()
+                pipe, _, _, _, _ = self._setup_pipeline_and_get_base_output(scheduler_cls)
+                pipe.load_lora_weights(os.path.join(tmpdirname))
                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
@@ -2001,14 +1090,10 @@ def test_set_adapters_match_attention_kwargs(self):
                    "Loading from saved checkpoints should give same results as set_adapters().",
                )

-    @require_peft_version_greater("0.13.2")
    def test_lora_B_bias(self):
        # Currently, this test is only relevant for Flux Control LoRA as we are not
        # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe, inputs, _, _, denoiser_lora_config = self._setup_pipeline_and_get_base_output(self.scheduler_classes[0])

        # keep track of the bias values of the base layers to perform checks later.
        bias_values = {}
@@ -2042,13 +1127,9 @@ def test_lora_B_bias(self):
        self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))

    def test_correct_lora_configs_with_different_ranks(self):
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        pipe, inputs, original_output, text_lora_config, denoiser_lora_config = (
+            self._setup_pipeline_and_get_base_output(self.scheduler_classes[0])
+        )

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
@@ -2151,7 +1232,6 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
        pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
        pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]

-    @require_peft_version_greater("0.14.0")
    def test_layerwise_casting_peft_input_autocast_denoiser(self):
        r"""
        A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
@@ -2194,10 +1274,7 @@ def check_module(denoiser):
                self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)

        # 1. Test forward with add_adapter
-        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device, dtype=compute_dtype)
-        pipe.set_progress_bar_config(disable=None)
+        pipe, inputs, _, _, denoiser_lora_config = self._setup_pipeline_and_get_base_output(self.scheduler_classes[0])

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
@@ -2212,7 +1289,6 @@ def check_module(denoiser):
        )
        check_module(denoiser)

-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        pipe(**inputs, generator=torch.manual_seed(0))[0]

        # 2. Test forward with load_lora_weights
@@ -2223,11 +1299,9 @@ def check_module(denoiser):
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )

-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
-            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device, dtype=compute_dtype)
-            pipe.set_progress_bar_config(disable=None)
+            pipe, inputs, _, _, denoiser_lora_config = self._setup_pipeline_and_get_base_output(
+                self.scheduler_classes[0]
+            )
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -2239,17 +1313,14 @@ def check_module(denoiser):
            )
            check_module(denoiser)

-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
            pipe(**inputs, generator=torch.manual_seed(0))[0]

    @parameterized.expand([4, 8, 16])
    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+        pipe, _, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output(
            scheduler_cls, lora_alpha=lora_alpha
        )
-        pipe = self.pipeline_class(**components)
-
        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )
@@ -2294,15 +1365,9 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
    @parameterized.expand([4, 8, 16])
    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+        pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output(
            scheduler_cls, lora_alpha=lora_alpha
        )
-        pipe = self.pipeline_class(**components).to(torch_device)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )
@@ -2341,7 +1406,6 @@ def test_lora_unload_add_adapter(self):
        )
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

-    @require_peft_version_greater("0.13.2")
    def test_lora_exclude_modules(self):
        """
        Test to check if `exclude_modules` works or not. It works in the following way:
@@ -2357,7 +1421,6 @@ def test_lora_exclude_modules(self):
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)

        # only supported for `denoiser` now
        pipe_cp = copy.deepcopy(pipe)
@@ -2393,64 +1456,13 @@ def test_lora_exclude_modules(self):
            "Lora outputs should match.",
        )

-    def test_inference_load_delete_load_adapters(self):
-        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-                )
-
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-            denoiser.add_adapter(denoiser_lora_config)
-            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-            if self.has_two_text_encoders or self.has_three_text_encoders:
-                lora_loadable_components = self.pipeline_class._lora_loadable_modules
-                if "text_encoder_2" in lora_loadable_components:
-                    pipe.text_encoder_2.add_adapter(text_lora_config)
-                    self.assertTrue(
-                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                    )
-
-            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-                self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
-                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
-
-                # First, delete adapter and compare.
-                pipe.delete_adapters(pipe.get_active_adapters()[0])
-                output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
-                self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
-                self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
-
-                # Then load adapter and compare.
-                pipe.load_lora_weights(tmpdirname)
-                output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
-                self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
-
    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook

        onload_device = torch_device
        offload_device = torch.device("cpu")

-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe, inputs, _, _, denoiser_lora_config = self._setup_pipeline_and_get_base_output(self.scheduler_classes[0])

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
@@ -2462,17 +1474,14 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
-            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
+            pipe, inputs, _, _, denoiser_lora_config = self._setup_pipeline_and_get_base_output(
+                self.scheduler_classes[0]
+            )
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            check_if_lora_correctly_set(denoiser)
-            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdirname)
+            self.assertTrue(check_if_lora_correctly_set(denoiser))

            # Test group offloading with load_lora_weights
            denoiser.enable_group_offload(
@@ -2490,11 +1499,11 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
            pipe.unload_lora_weights()
            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
            self.assertTrue(group_offload_hook_2 is not None)
-            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
+            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841

            # Add the lora again and check if group offloading works
-            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-            check_if_lora_correctly_set(denoiser)
+            pipe.load_lora_weights(tmpdirname)
+            self.assertTrue(check_if_lora_correctly_set(denoiser))
            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
            self.assertTrue(group_offload_hook_3 is not None)
            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2510,3 +1519,220 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        # materializes the test methods on invocation which cannot be overridden.
        return self._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
+        if self.unet_kwargs and self.transformer_kwargs:
+            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
+        if self.has_two_text_encoders and self.has_three_text_encoders:
+            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
+
+        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
+        rank = 4
+        lora_alpha = rank if lora_alpha is None else lora_alpha
+
+        torch.manual_seed(0)
+        if self.unet_kwargs is not None:
+            unet = UNet2DConditionModel(**self.unet_kwargs)
+        else:
+            transformer = self.transformer_cls(**self.transformer_kwargs)
+
+        scheduler = scheduler_cls(**self.scheduler_kwargs)
+
+        torch.manual_seed(0)
+        vae = self.vae_cls(**self.vae_kwargs)
+
+        text_encoder = self.text_encoder_cls.from_pretrained(
+            self.text_encoder_id, subfolder=self.text_encoder_subfolder
+        )
+        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)
+
+        if self.text_encoder_2_cls is not None:
+            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
+                self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
+            )
+            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
+                self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
+            )
+
+        if self.text_encoder_3_cls is not None:
+            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
+                self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
+            )
+            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
+                self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
+            )
+
+        text_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=self.text_encoder_target_modules,
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=self.denoiser_target_modules,
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+
+        pipeline_components = {
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        # Denoiser
+        if self.unet_kwargs is not None:
+            pipeline_components.update({"unet": unet})
+        elif self.transformer_kwargs is not None:
+            pipeline_components.update({"transformer": transformer})
+
+        # Remaining text encoders.
+        if self.text_encoder_2_cls is not None:
+            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
+        if self.text_encoder_3_cls is not None:
+            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})
+
+        # Remaining stuff
+        init_params = inspect.signature(self.pipeline_class.__init__).parameters
+        if "safety_checker" in init_params:
+            pipeline_components.update({"safety_checker": None})
+        if "feature_extractor" in init_params:
+            pipeline_components.update({"feature_extractor": None})
+        if "image_encoder" in init_params:
+            pipeline_components.update({"image_encoder": None})
+
+        return pipeline_components, text_lora_config, denoiser_lora_config
+
+    @property
+    def output_shape(self):
+        raise NotImplementedError
+
+    def get_dummy_inputs(self, with_generator=True):
+        batch_size = 1
+        sequence_length = 10
+        num_channels = 4
+        sizes = (32, 32)
+
+        generator = torch.manual_seed(0)
+        noise = floats_tensor((batch_size, num_channels) + sizes)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+        pipeline_inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "num_inference_steps": 5,
+            "guidance_scale": 6.0,
+            "output_type": "np",
+        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
+
+        return noise, input_ids, pipeline_inputs
+
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
+        if text_lora_config is not None and "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        if denoiser_lora_config is not None:
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+        else:
+            denoiser = None
+
+        if text_lora_config is not None and (self.has_two_text_encoders or self.has_three_text_encoders):
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+        return pipe, denoiser
+
+    def _setup_pipeline_and_get_base_output(self, scheduler_cls, lora_alpha=4, use_dora=False):
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+            scheduler_cls, lora_alpha=lora_alpha, use_dora=use_dora
+        )
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+        return pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config
+
+    def _setup_multi_adapter_pipeline(self, scheduler_cls):
+        """
+        Helper to set up a pipeline with two LoRA adapters ("adapter-1", "adapter-2")
+        attached to the text encoder and denoiser.
+ """ + pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config = ( + self._setup_pipeline_and_get_base_output(scheduler_cls) + ) + + # Add adapter-1 + pipe, denoiser = self.add_adapters_to_pipeline( + pipe, text_lora_config, denoiser_lora_config, adapter_name="adapter-1" + ) + # Add adapter-2 + pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config, adapter_name="adapter-2") + + return pipe, inputs, output_no_lora, denoiser + + def _get_lora_state_dicts(self, modules_to_save): + state_dicts = {} + for module_name, module in modules_to_save.items(): + if module is not None: + state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) + return state_dicts + + def _get_lora_adapter_metadata(self, modules_to_save): + metadatas = {} + for module_name, module in modules_to_save.items(): + if module is not None: + metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict() + return metadatas + + def _get_modules_to_save(self, pipe, has_denoiser=False): + modules_to_save = {} + lora_loadable_modules = self.pipeline_class._lora_loadable_modules + + if ( + "text_encoder" in lora_loadable_modules + and hasattr(pipe, "text_encoder") + and getattr(pipe.text_encoder, "peft_config", None) is not None + ): + modules_to_save["text_encoder"] = pipe.text_encoder + + if ( + "text_encoder_2" in lora_loadable_modules + and hasattr(pipe, "text_encoder_2") + and getattr(pipe.text_encoder_2, "peft_config", None) is not None + ): + modules_to_save["text_encoder_2"] = pipe.text_encoder_2 + + if has_denoiser: + if "unet" in lora_loadable_modules and hasattr(pipe, "unet"): + modules_to_save["unet"] = pipe.unet + + if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"): + modules_to_save["transformer"] = pipe.transformer + + return modules_to_save + + def _get_exclude_modules(self, pipe): + from diffusers.utils.peft_utils import _derive_exclude_modules + + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + denoiser = "unet" if self.unet_kwargs is not None else "transformer" + modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser} + denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"] + pipe.unload_lora_weights() + denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict() + exclude_modules = _derive_exclude_modules( + denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default" + ) + return exclude_modules