Skip to content

Commit 26b21ae

Browse files
feat: verify function called while converting model
1 parent 0cad89d commit 26b21ae

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import unittest.mock
12
import pytest
23
import torch
4+
import unittest
35

46

57
from invokeai.backend.patches.layers.utils import swap_shift_scale_for_linear_weight
@@ -131,5 +133,20 @@ def test_approximate_adaLN_from_state_dict_should_work(dtype: torch.dtype, rtol:
131133

132134
assert close_rate > rate
133135

136+
def test_adaLN_should_be_approximated_if_present_while_converting():
137+
"""The AdaLN layer should be approximated if it exists inside the given model"""
138+
state_dict = keys_to_mock_state_dict(flux_diffusers_with_norm_out_state_dict_keys)
134139

140+
adaLN_layer_key = 'final_layer.adaLN_modulation.1'
141+
prefixed_layer_key = FLUX_LORA_TRANSFORMER_PREFIX + adaLN_layer_key
135142

143+
with unittest.mock.patch(
144+
'invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils.approximate_flux_adaLN_lora_layer_from_diffusers_state_dict'
145+
) as mock_approximate_func:
146+
model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
147+
148+
# Check that every layer key in the model carries the expected transformer prefix.
149+
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
150+
151+
assert prefixed_layer_key in model.layers.keys()
152+
assert mock_approximate_func.call_count == 1

0 commit comments

Comments
 (0)