Commit 4af7273

feat: approximate adaLN layer for more compatibility
1 parent bb351a6 commit 4af7273

File tree

5 files changed, +114 -83 lines changed


invokeai/backend/patches/layers/diffusers_ada_ln_lora_layer.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

invokeai/backend/patches/layers/utils.py

Lines changed: 66 additions & 5 deletions
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Tuple
 
 import torch
 
@@ -10,7 +10,6 @@
 from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.norm_layer import NormLayer
-from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer
 
 
 def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
@@ -36,8 +35,70 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
     raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
 
 
-def diffusers_adaLN_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> DiffusersAdaLN_LoRALayer:
+
+def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
+    """Swap the shift/scale halves of a linear layer's weight; applying this twice restores the original."""
+    # The SD3/Flux implementation of AdaLayerNormContinuous splits the linear projection output into shift, scale,
+    # while diffusers splits it into scale, shift. Swapping the two halves converts between the layouts.
+    chunk1, chunk2 = weight.chunk(2, dim=0)
+    return torch.cat([chunk2, chunk1], dim=0)
+
+def decomposite_weight_matric_with_rank(
+    delta: torch.Tensor,
+    rank: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Decompose the given matrix into two low-rank factors of the specified rank via truncated SVD."""
+    U, S, V = torch.svd(delta)
+
+    # Truncate to rank r:
+    U_r = U[:, :rank]
+    S_r = S[:rank]
+    V_r = V[:, :rank]
+
+    S_sqrt = torch.sqrt(S_r)
+
+    up = torch.matmul(U_r, torch.diag(S_sqrt))
+    down = torch.matmul(torch.diag(S_sqrt), V_r.T)
+
+    return up, down
+
+
+def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
+    """Approximate the given diffusers AdaLN LoRA layer as a LoRA layer for our Flux model."""
+
     if not "lora_up.weight" in state_dict:
-        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
+        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
 
-    return DiffusersAdaLN_LoRALayer.from_state_dict_values(state_dict)
+    if not "lora_down.weight" in state_dict:
+        raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
+
+    up = state_dict.pop('lora_up.weight')
+    down = state_dict.pop('lora_down.weight')
+
+    dtype = up.dtype
+    device = up.device
+    up_shape = up.shape
+    down_shape = down.shape
+
+    # desired low rank
+    rank = up_shape[1]
+
+    # upcast to double for a more precise reconstruction
+    up = up.double()
+    down = down.double()
+    weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
+
+    # swap to our linear format
+    swapped = swap_shift_scale_for_linear_weight(weight)
+
+    _up, _down = decomposite_weight_matric_with_rank(swapped, rank)
+
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape
+
+    # cast back to the original dtype and device
+    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
+    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+
+    return LoRALayer.from_state_dict_values(state_dict)
+

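For readers unfamiliar with the two AdaLN conventions, the sketch below illustrates why swapping the weight halves is enough to convert between them. It is a toy example, not InvokeAI or diffusers code: the two module classes, their names, and the omission of the layer norm are simplifying assumptions made purely for illustration.

# Toy stand-ins for the two AdaLN orderings (assumed/simplified; the real
# layers also normalize x and build the projection from an activation).
import torch
import torch.nn as nn


class DiffusersStyleAdaLN(nn.Module):
    """Chunks the projection output as (scale, shift), the diffusers ordering."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(emb).chunk(2, dim=-1)
        return x * (1 + scale) + shift


class FluxStyleAdaLN(nn.Module):
    """Chunks the projection output as (shift, scale), the SD3/Flux ordering."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        shift, scale = self.linear(emb).chunk(2, dim=-1)
        return x * (1 + scale) + shift


def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
    # Same operation as the helper added in utils.py above.
    chunk1, chunk2 = weight.chunk(2, dim=0)
    return torch.cat([chunk2, chunk1], dim=0)


dim = 4
x, emb = torch.randn(1, dim), torch.randn(1, dim)

diffusers_layer = DiffusersStyleAdaLN(dim)
flux_layer = FluxStyleAdaLN(dim)

# Copying the diffusers weight with its halves swapped makes the two
# layers compute identical outputs.
with torch.no_grad():
    flux_layer.linear.weight.copy_(swap_shift_scale_for_linear_weight(diffusers_layer.linear.weight))

assert torch.allclose(diffusers_layer(x, emb), flux_layer(x, emb))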
invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, diffusers_adaLN_lora_layer_from_state_dict
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, approximate_flux_adaLN_lora_layer_from_diffusers_state_dict
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 
@@ -86,7 +86,7 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
         if src_key in grouped_state_dict:
             src_layer_dict = grouped_state_dict.pop(src_key)
             values = get_lora_layer_values(src_layer_dict)
-            layers[dst_key] = diffusers_adaLN_lora_layer_from_state_dict(values)
+            layers[dst_key] = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(values)
 
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],

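The new helper relies on the standard truncated-SVD identity: with U_r, S_r, V_r the leading rank-r factors of the swapped delta, up = U_r @ diag(sqrt(S_r)) and down = diag(sqrt(S_r)) @ V_r.T, so up @ down = U_r @ diag(S_r) @ V_r.T. Because the shift/scale swap is only a row permutation, the swapped delta still has rank at most r and is recovered exactly (up to floating-point error). The minimal numerical sketch below checks this using only the two helpers added in this commit; the matrix sizes are arbitrary illustration values, not Flux's real dimensions, and the LoRALayer wrapping is omitted.

# Numerical sketch of the core of approximate_flux_adaLN_lora_layer_from_diffusers_state_dict,
# using only the helpers added in this commit. Sizes are illustrative only.
import torch

from invokeai.backend.patches.layers.utils import (
    decomposite_weight_matric_with_rank,
    swap_shift_scale_for_linear_weight,
)

out_features, in_features, rank = 128, 64, 8

# A synthetic diffusers-format AdaLN LoRA delta with rank <= `rank`.
up = torch.randn(out_features, rank).double()
down = torch.randn(rank, in_features).double()
delta = up @ down

# Swap the shift/scale halves into the Flux layout, then re-factor at the same rank.
swapped = swap_shift_scale_for_linear_weight(delta)
new_up, new_down = decomposite_weight_matric_with_rank(swapped, rank)

# The factor shapes match the originals, and the low-rank product reproduces the
# swapped delta exactly (up to floating-point error), since a row permutation
# cannot increase the rank.
assert new_up.shape == up.shape and new_down.shape == down.shape
assert torch.allclose(new_up @ new_down, swapped)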
tests/backend/patches/layers/test_diffuser_ada_ln_lora_layer.py

Lines changed: 0 additions & 57 deletions
This file was deleted.
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import torch
+
+from invokeai.backend.patches.layers.utils import decomposite_weight_matric_with_rank, swap_shift_scale_for_linear_weight
+
+
+def test_swap_shift_scale_for_linear_weight():
+    """Test that swapping the shift/scale halves works and is its own inverse."""
+    original = torch.Tensor([1, 2])
+    expected = torch.Tensor([2, 1])
+
+    swapped = swap_shift_scale_for_linear_weight(original)
+    assert torch.allclose(expected, swapped)
+
+    size = (3, 4)
+    first = torch.randn(size)
+    second = torch.randn(size)
+
+    original = torch.concat([first, second])
+    expected = torch.concat([second, first])
+
+    swapped = swap_shift_scale_for_linear_weight(original)
+    assert torch.allclose(expected, swapped)
+
+    # calling it twice reconstructs the original
+    reconstructed = swap_shift_scale_for_linear_weight(swapped)
+    assert torch.allclose(reconstructed, original)
+
+def test_decomposite_weight_matric_with_rank():
+    """Test that decomposing a matrix into two low-rank factors works; W0 has rank <= 8, so the rank-8 factorization reconstructs it exactly."""
+    input_dim = 1024
+    output_dim = 1024
+    rank = 8  # Low rank
+
+
+    A = torch.randn(input_dim, rank).double()
+    B = torch.randn(rank, output_dim).double()
+    W0 = A @ B
+
+    C, D = decomposite_weight_matric_with_rank(W0, rank)
+    R = C @ D
+
+    assert C.shape == A.shape
+    assert D.shape == B.shape
+
+    assert torch.allclose(W0, R)
+
