
Commit b087694

feat: add tests for DiffuserAdaLN layer logic
1 parent 4f1b6ce commit b087694

File tree

2 files changed (+60, -3 lines)

invokeai/backend/patches/layers/diffusers_ada_ln_lora_layer.py

Lines changed: 5 additions & 3 deletions
@@ -2,6 +2,10 @@
 
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 
+def swap_shift_scale(tensor: torch.Tensor) -> torch.Tensor:
+    scale, shift = tensor.chunk(2, dim=0)
+    return torch.cat([shift, scale], dim=0)
+
 class DiffusersAdaLN_LoRALayer(LoRALayer):
     '''LoRA layer converted from Diffusers AdaLN, weight is shift-scale swapped'''
 
@@ -11,6 +15,4 @@ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         # So we swap the linear projection weights in order to be able to use Flux implementation
 
         weight = super().get_weight(orig_weight)
-        scale, shift = weight.chunk(2, dim=0)
-
-        return torch.cat([shift, scale], dim=0)
+        return swap_shift_scale(weight)
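
A minimal standalone sketch (not from the commit itself) of what the new helper does, assuming the [scale, shift] input layout implied by its variable names: swap_shift_scale splits a weight into two equal halves along dim 0 and returns them in the opposite order, which is the layout the Flux code path expects per the in-line comment.

    import torch

    from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import swap_shift_scale

    # Two recognizable halves stacked along dim 0: a "scale" block followed by a "shift" block.
    scale = torch.zeros(3, 4)
    shift = torch.ones(3, 4)
    weight = torch.cat([scale, shift], dim=0)

    swapped = swap_shift_scale(weight)

    # After the swap the shift block comes first, then the scale block.
    assert torch.allclose(swapped[:3], shift)
    assert torch.allclose(swapped[3:], scale)

Any tensor with an even-sized first dimension is handled the same way; the helper does not check that the two halves really correspond to scale and shift.
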
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import torch
+
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer, swap_shift_scale
+
+def test_swap_shift_scale_for_tensor():
+    """Test the swapping function."""
+    tensor = torch.Tensor([1, 2])
+    expected = torch.Tensor([2, 1])
+
+    swapped = swap_shift_scale(tensor)
+    assert torch.allclose(expected, swapped)
+
+    size = (3, 4)
+    first = torch.randn(size)
+    second = torch.randn(size)
+
+    tensor = torch.concat([first, second])
+    expected = torch.concat([second, first])
+
+    swapped = swap_shift_scale(tensor)
+    assert torch.allclose(expected, swapped)
+
+def test_diffusers_adaLN_lora_layer_get_weight():
+    """Test getting the weight from DiffusersAdaLN_LoRALayer."""
+    small_in_features = 4
+    big_in_features = 8
+    out_features = 16
+    rank = 4
+    alpha = 16.0
+
+    lora = LoRALayer(
+        up=torch.ones(out_features, rank),
+        mid=None,
+        down=torch.ones(rank, big_in_features),
+        alpha=alpha,
+        bias=None
+    )
+    layer = DiffusersAdaLN_LoRALayer(
+        up=torch.ones(out_features, rank),
+        mid=None,
+        down=torch.ones(rank, big_in_features),
+        alpha=alpha,
+        bias=None
+    )
+
+    # Mock original weight, normally ignored in our LoRA layers
+    orig_weight = torch.ones(small_in_features)
+
+    diffuser_weight = layer.get_weight(orig_weight)
+    lora_weight = lora.get_weight(orig_weight)
+
+    # Diffusers LoRA weight should be flipped (shift/scale swapped)
+    assert torch.allclose(diffuser_weight, swap_shift_scale(lora_weight))
+
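
A possible way to run just the two new tests, selecting them by name with pytest's -k keyword expression (a sketch; the new test file's path is not listed above, so no path is assumed):

    # Hypothetical invocation; assumes pytest is installed in the environment.
    import pytest

    pytest.main(["-k", "swap_shift_scale or diffusers_adaLN_lora_layer"])
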
