@@ -1,4 +1,4 @@
- from typing import Dict
+ from typing import Dict, Tuple

import torch

@@ -10,7 +10,6 @@
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.norm_layer import NormLayer
- from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer


def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
@@ -36,8 +35,70 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")


- def diffusers_adaLN_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> DiffusersAdaLN_LoRALayer:
+
+ def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
+     """Swap the shift/scale halves of a linear layer's weight (the swap works in either direction)."""
+     # The SD3 and Flux implementations of AdaLayerNormContinuous split the linear projection output into shift, scale,
+     # while diffusers splits it into scale, shift. Swapping the two halves converts between the layouts.
+     chunk1, chunk2 = weight.chunk(2, dim=0)
+     return torch.cat([chunk2, chunk1], dim=0)
+
+ def decomposite_weight_matric_with_rank(
+     delta: torch.Tensor,
+     rank: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Decompose the given matrix into up/down factors of the specified rank via truncated SVD."""
+     U, S, V = torch.svd(delta)
+
+     # Truncate to rank r:
+     U_r = U[:, :rank]
+     S_r = S[:rank]
+     V_r = V[:, :rank]
+
+     S_sqrt = torch.sqrt(S_r)
+
+     up = torch.matmul(U_r, torch.diag(S_sqrt))
+     down = torch.matmul(torch.diag(S_sqrt), V_r.T)
+
+     return up, down
+
+
+ def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
+     """Approximate the given diffusers AdaLN LoRA layer for use with our Flux model."""
+
    if not "lora_up.weight" in state_dict:
-         raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
+         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")

-     return DiffusersAdaLN_LoRALayer.from_state_dict_values(state_dict)
+     if not "lora_down.weight" in state_dict:
+         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
+
+     up = state_dict.pop('lora_up.weight')
+     down = state_dict.pop('lora_down.weight')
+
+     dtype = up.dtype
+     device = up.device
+     up_shape = up.shape
+     down_shape = down.shape
+
+     # desired low rank
+     rank = up_shape[1]
+
+     # upcast to double precision so the composed weight and its decomposition are more accurate
+     up = up.double()
+     down = down.double()
+     weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
+
+     # swap to our linear (shift, scale) layout
+     swapped = swap_shift_scale_for_linear_weight(weight)
+
+     _up, _down = decomposite_weight_matric_with_rank(swapped, rank)
+
+     assert _up.shape == up_shape
+     assert _down.shape == down_shape
+
+     # cast back to the original dtype and device
+     state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
+     state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+
+     return LoRALayer.from_state_dict_values(state_dict)
+
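For reference, a standalone sketch (not part of this commit) of how the two new helpers are meant to compose: build the full weight delta from a diffusers lora_up/lora_down pair, swap the scale/shift halves into the Flux layout, and re-factor the result at the same rank. It assumes the helpers are importable from invokeai.backend.patches.layers.utils and uses made-up tensor sizes.

import torch

from invokeai.backend.patches.layers.utils import (
    decomposite_weight_matric_with_rank,
    swap_shift_scale_for_linear_weight,
)

# Hypothetical sizes: a rank-4 LoRA pair for a projection with an even output dimension,
# since swap_shift_scale_for_linear_weight chunks the weight into two equal halves along dim 0.
rank, in_features, out_features = 4, 16, 32
up = torch.randn(out_features, rank, dtype=torch.float64)
down = torch.randn(rank, in_features, dtype=torch.float64)

delta = up @ down                                    # full weight delta represented by the LoRA pair
swapped = swap_shift_scale_for_linear_weight(delta)  # diffusers (scale, shift) -> Flux (shift, scale)
new_up, new_down = decomposite_weight_matric_with_rank(swapped, rank)

assert new_up.shape == up.shape and new_down.shape == down.shape
# delta has rank <= 4, so the rank-4 truncated SVD reconstructs the swapped delta up to numerical error.
assert torch.allclose(new_up @ new_down, swapped, atol=1e-6)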