
Commit 0cad89d

feat: refactor conversion module, add test for svd correctness
1 parent 4af7273 commit 0cad89d

File tree

3 files changed: +102 -43 lines changed

invokeai/backend/patches/layers/utils.py

Lines changed: 2 additions & 42 deletions
@@ -46,6 +46,7 @@ def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
 def decomposite_weight_matric_with_rank(
     delta: torch.Tensor,
     rank: int,
+    epsilon: float = 1e-8,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Decompose given matrix with a specified rank."""
     U, S, V = torch.svd(delta)
@@ -55,50 +56,9 @@ def decomposite_weight_matric_with_rank
     S_r = S[:rank]
     V_r = V[:, :rank]
 
-    S_sqrt = torch.sqrt(S_r)
+    S_sqrt = torch.sqrt(S_r + epsilon)  # regularization
 
     up = torch.matmul(U_r, torch.diag(S_sqrt))
     down = torch.matmul(torch.diag(S_sqrt), V_r.T)
 
     return up, down
-
-
-def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
-    '''Approximate given diffusers AdaLN loRA layer in our Flux model'''
-
-    if not "lora_up.weight" in state_dict:
-        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
-
-    if not "lora_down.weight" in state_dict:
-        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
-
-    up = state_dict.pop('lora_up.weight')
-    down = state_dict.pop('lora_down.weight')
-
-    dtype = up.dtype
-    device = up.device
-    up_shape = up.shape
-    down_shape = down.shape
-
-    # desired low rank
-    rank = up_shape[1]
-
-    # up scaling for more precise
-    up.double()
-    down.double()
-    weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
-
-    # swap to our linear format
-    swapped = swap_shift_scale_for_linear_weight(weight)
-
-    _up, _down = decomposite_weight_matric_with_rank(swapped, rank)
-
-    assert(_up.shape == up_shape)
-    assert(_down.shape == down_shape)
-
-    # down scaling to original dtype, device
-    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
-    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
-
-    return LoRALayer.from_state_dict_values(state_dict)
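The epsilon change above only touches decomposite_weight_matric_with_rank, which splits a weight delta into rank-r up/down factors via SVD. A minimal standalone sketch of that factorization (the helper name lowrank_factors and the example shapes are illustrative, not from the repo):

# Sketch (not part of this commit) of the rank-r SVD factorization used above,
# including the epsilon term that keeps torch.sqrt stable when singular values are near zero.
import torch

def lowrank_factors(delta: torch.Tensor, rank: int, epsilon: float = 1e-8):
    U, S, V = torch.svd(delta)            # delta ~ U @ diag(S) @ V.T
    U_r, S_r, V_r = U[:, :rank], S[:rank], V[:, :rank]
    S_sqrt = torch.sqrt(S_r + epsilon)    # split diag(S) evenly between the two factors
    up = U_r @ torch.diag(S_sqrt)         # shape: (out_features, rank)
    down = torch.diag(S_sqrt) @ V_r.T     # shape: (rank, in_features)
    return up, down

delta = torch.randn(64, 32)
up, down = lowrank_factors(delta, rank=8)
assert up.shape == (64, 8) and down.shape == (8, 32)
# up @ down is (up to the epsilon shift) the best rank-8 approximation of delta.

Splitting sqrt(S) across both factors keeps up and down on a similar scale, matching how the repo's helper builds the pair.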

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 45 additions & 1 deletion
@@ -2,9 +2,10 @@
 
 import torch
 
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, approximate_flux_adaLN_lora_layer_from_diffusers_state_dict
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, swap_shift_scale_for_linear_weight, decomposite_weight_matric_with_rank
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -29,6 +30,49 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
 
     return all_keys_in_peft_format and all_expected_keys_present
 
+def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
+    '''Approximate given diffusers AdaLN LoRA layer in our Flux model'''
+
+    if "lora_up.weight" not in state_dict:
+        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
+
+    if "lora_down.weight" not in state_dict:
+        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
+
+    up = state_dict.pop('lora_up.weight')
+    down = state_dict.pop('lora_down.weight')
+
+    # The layer patcher upcasts everything to float32,
+    # so keep the higher precision for this layer as well.
+    dtype = torch.float32
+
+    device = up.device
+    up_shape = up.shape
+    down_shape = down.shape
+
+    # desired low rank
+    rank = up_shape[1]
+
+    # upcast for a more precise reconstruction
+    up = up.to(torch.float32)
+    down = down.to(torch.float32)
+
+    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)
+
+    # swap to our linear format
+    swapped = swap_shift_scale_for_linear_weight(weight)
+
+    _up, _down = decomposite_weight_matric_with_rank(swapped, rank)
+
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape
+
+    # cast back to the target dtype and device
+    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
+    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+
+    return LoRALayer.from_state_dict_values(state_dict)
+
 
 def lora_model_from_flux_diffusers_state_dict(
     state_dict: Dict[str, torch.Tensor], alpha: float | None

tests/backend/patches/lora_conversions/test_flux_diffusers_lora_conversion_utils.py

Lines changed: 55 additions & 0 deletions
@@ -1,9 +1,12 @@
 import pytest
 import torch
 
+
+from invokeai.backend.patches.layers.utils import swap_shift_scale_for_linear_weight
 from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
+    approximate_flux_adaLN_lora_layer_from_diffusers_state_dict,
 )
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
@@ -78,3 +81,55 @@ def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
     # Check that an error is raised.
     with pytest.raises(AssertionError):
         lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
+
+
+@pytest.mark.parametrize("layer_sd_keys", [
+    {},  # no keys
+    {'lora_A.weight': [1024, 8], 'lora_B.weight': [8, 512]},  # wrong keys
+    {'lora_up.weight': [1024, 8]},  # missing key
+    {'lora_down.weight': [8, 512]},  # missing key
+])
+def test_approximate_adaLN_from_state_dict_should_only_accept_vanilla_LoRA_format(layer_sd_keys: dict[str, list[int]]):
+    """Should only accept a valid vanilla LoRA state dict."""
+    layer_state_dict = keys_to_mock_state_dict(layer_sd_keys)
+
+    with pytest.raises(ValueError):
+        approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(layer_state_dict)
+
+
+@pytest.mark.parametrize("dtype, rtol", [
+    (torch.float32, 1e-4),
+    (torch.half, 1e-3),
+])
+def test_approximate_adaLN_from_state_dict_should_work(dtype: torch.dtype, rtol: float, rate: float = 0.99):
+    """Test that we can approximate a good-enough adaLN layer from a diffusers state dict.
+    The tolerated error depends on the input dtype."""
+    input_dim = 1024
+    output_dim = 512
+    rank = 8  # low rank
+    total = input_dim * output_dim
+
+    up = torch.randn(input_dim, rank, dtype=dtype)
+    down = torch.randn(rank, output_dim, dtype=dtype)
+
+    layer_state_dict = {
+        'lora_up.weight': up,
+        'lora_down.weight': down,
+    }
+
+    # The layer patcher casts everything to float32.
+    original = up.float() @ down.float()
+    swapped = swap_shift_scale_for_linear_weight(original)
+
+    layer = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(layer_state_dict)
+    weight = layer.get_weight(original).float()
+
+    print(weight.dtype, swapped.dtype, layer.up.dtype)
+
+    close_count = torch.isclose(weight, swapped, rtol=rtol).sum().item()
+    close_rate = close_count / total
+
+    assert close_rate > rate
