Skip to content

Commit 4f1b6ce

Browse files
feat: add new layer type for diffusers-ada-ln
1 parent c12005e commit 4f1b6ce

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
3+
from invokeai.backend.patches.layers.lora_layer import LoRALayer
4+
5+
class DiffusersAdaLN_LoRALayer(LoRALayer):
    """LoRA layer converted from a Diffusers AdaLN layer.

    The patched weight is identical to a plain LoRALayer's, except that the
    two halves of dim 0 (scale/shift) are swapped to match the layout
    expected by the SD3/Flux AdaLayerNormContinuous implementation.
    """

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the LoRA weight with its shift/scale halves swapped.

        The SD3/Flux implementation of AdaLayerNormContinuous splits the
        linear projection output into (shift, scale), while Diffusers splits
        it into (scale, shift). Swapping the two halves of the projection
        weight lets the Flux implementation consume a Diffusers-trained LoRA.
        """
        converted = super().get_weight(orig_weight)
        # chunk(2, dim=0) yields (scale, shift); reverse the pair before
        # concatenating to obtain the (shift, scale) ordering Flux expects.
        halves = converted.chunk(2, dim=0)
        return torch.cat(halves[::-1], dim=0)

invokeai/backend/patches/layers/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
1111
from invokeai.backend.patches.layers.lora_layer import LoRALayer
1212
from invokeai.backend.patches.layers.norm_layer import NormLayer
13+
from invokeai.backend.patches.layers.diffusers_ada_ln_lora_layer import DiffusersAdaLN_LoRALayer
1314

1415

1516
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
@@ -33,3 +34,10 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
3334
return NormLayer.from_state_dict_values(state_dict)
3435
else:
3536
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
37+
38+
39+
def diffusers_adaLN_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> DiffusersAdaLN_LoRALayer:
    """Build a DiffusersAdaLN_LoRALayer from a Diffusers-format LoRA state dict.

    Args:
        state_dict: Tensors for a single layer; must contain "lora_up.weight"
            (the kohya-style key this converter recognizes).

    Returns:
        A DiffusersAdaLN_LoRALayer that swaps the shift/scale halves of the
        patched weight (see that class for why).

    Raises:
        ValueError: If the state dict does not look like a supported LoRA
            layer (missing "lora_up.weight").
    """
    # PEP 8 idiom: `x not in d`, not `not x in d`.
    if "lora_up.weight" not in state_dict:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

    return DiffusersAdaLN_LoRALayer.from_state_dict_values(state_dict)

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
66
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
7-
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
7+
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, diffusers_adaLN_lora_layer_from_state_dict
88
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
99
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
1010

@@ -86,15 +86,8 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
8686
if src_key in grouped_state_dict:
8787
src_layer_dict = grouped_state_dict.pop(src_key)
8888
values = get_lora_layer_values(src_layer_dict)
89-
90-
for _key in values.keys():
91-
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
92-
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
93-
scale, shift = values[_key].chunk(2, dim=0)
94-
values[_key] = torch.cat([shift, scale], dim=0)
95-
96-
layers[dst_key] = any_lora_layer_from_state_dict(values)
97-
89+
layers[dst_key] = diffusers_adaLN_lora_layer_from_state_dict(values)
90+
9891
def add_qkv_lora_layer_if_present(
9992
src_keys: list[str],
10093
src_weight_shapes: list[tuple[int, int]],

0 commit comments

Comments
 (0)