@@ -82,6 +82,19 @@ def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
             values = get_lora_layer_values(src_layer_dict)
             layers[dst_key] = any_lora_layer_from_state_dict(values)
 
+    def add_lora_adaLN_layer_if_present(src_key: str, dst_key: str) -> None:
+        if src_key in grouped_state_dict:
+            src_layer_dict = grouped_state_dict.pop(src_key)
+            values = get_lora_layer_values(src_layer_dict)
+
+            for _key in values.keys():
+                # The original SD3 implementation of AdaLayerNormContinuous splits the linear projection output into
+                # (shift, scale), while diffusers splits it into (scale, shift). Swap the two halves of the linear projection weights so that the diffusers implementation can be used.
+                scale, shift = values[_key].chunk(2, dim=0)
+                values[_key] = torch.cat([shift, scale], dim=0)
+
+            layers[dst_key] = any_lora_layer_from_state_dict(values)
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
@@ -223,6 +236,10 @@ def add_qkv_lora_layer_if_present(
 
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
+    add_lora_adaLN_layer_if_present(
+        "norm_out.linear",
+        "final_layer.adaLN_modulation.1",
+    )
 
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0
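
As a sanity check on the shift/scale swap in `add_lora_adaLN_layer_if_present`, here is a minimal standalone sketch (the `proj` tensor is a hypothetical stand-in, not an actual model weight) showing that concatenating the two halves in reverse order turns a (scale, shift) layout into a (shift, scale) layout:

```python
import torch

# Hypothetical stand-in for an adaLN linear projection weight along dim 0:
# in the diffusers layout, the first half is scale and the second half is shift.
proj = torch.arange(8.0)

scale, shift = proj.chunk(2, dim=0)  # diffusers layout: (scale, shift)

# Swap the halves, as in add_lora_adaLN_layer_if_present above.
swapped = torch.cat([shift, scale], dim=0)

# Reading the swapped tensor in the (shift, scale) layout recovers the same halves.
shift2, scale2 = swapped.chunk(2, dim=0)
assert torch.equal(shift2, shift)
assert torch.equal(scale2, scale)
```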