Skip to content

Commit bb351a6

Browse files
feat: refine swap logic
1 parent b5865fe commit bb351a6

File tree

2 files changed

+10
-7
lines changed

invokeai/backend/patches/layers/diffusers_ada_ln_lora_layer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
1414
# while in diffusers it split into scale, shift.
1515
# So we swap the linear projection weights in order to be able to use Flux implementation
1616

17-
weight = super().get_weight(orig_weight)
18-
return swap_shift_scale(weight)
17+
weight = super().get_weight(orig_weight)
18+
weight = swap_shift_scale(weight)
19+
return weight

tests/backend/patches/layers/test_diffuser_ada_ln_lora_layer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,29 @@ def test_diffusers_adaLN_lora_layer_get_weight():
2929
rank = 4
3030
alpha = 16.0
3131

32-
lora = LoRALayer(
32+
normal_layer = LoRALayer(
3333
up=torch.ones(out_features, rank),
3434
mid=None,
3535
down=torch.ones(rank, big_in_features),
3636
alpha=alpha,
3737
bias=None
3838
)
39-
layer = DiffusersAdaLN_LoRALayer(
39+
diffuser_adaLN_layer = DiffusersAdaLN_LoRALayer(
4040
up=torch.ones(out_features, rank),
4141
mid=None,
4242
down=torch.ones(rank, big_in_features),
4343
alpha=alpha,
4444
bias=None
4545
)
4646

47+
assert(isinstance(diffuser_adaLN_layer, LoRALayer))
48+
4749
# mock original weight, normally ignored in our loRA
4850
orig_weight = torch.ones(small_in_features)
4951

50-
diffuser_weight = layer.get_weight(orig_weight)
51-
lora_weight = lora.get_weight(orig_weight)
52+
diffuser_weight = diffuser_adaLN_layer.get_weight(orig_weight)
53+
normal_weight = normal_layer.get_weight(orig_weight)
5254

5355
# diffusers lora weight should be flipped
54-
assert(torch.allclose(diffuser_weight, swap_shift_scale(lora_weight)))
56+
assert(torch.allclose(diffuser_weight, swap_shift_scale(normal_weight)))
5557

0 commit comments

Comments
 (0)