75
75
"stable_cascade_stage_b" : "down_blocks.1.0.channelwise.0.weight" ,
76
76
"stable_cascade_stage_c" : "clip_txt_mapper.weight" ,
77
77
"sd3" : "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias" ,
78
+ "sd35_large" : "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight" ,
78
79
"animatediff" : "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe" ,
79
80
"animatediff_v2" : "mid_block.motion_modules.0.temporal_transformer.norm.bias" ,
80
81
"animatediff_sdxl_beta" : "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight" ,
113
114
"sd3" : {
114
115
"pretrained_model_name_or_path" : "stabilityai/stable-diffusion-3-medium-diffusers" ,
115
116
},
117
+ "sd35_large" : {
118
+ "pretrained_model_name_or_path" : "stabilityai/stable-diffusion-3.5-large" ,
119
+ },
116
120
"animatediff_v1" : {"pretrained_model_name_or_path" : "guoyww/animatediff-motion-adapter-v1-5" },
117
121
"animatediff_v2" : {"pretrained_model_name_or_path" : "guoyww/animatediff-motion-adapter-v1-5-2" },
118
122
"animatediff_v3" : {"pretrained_model_name_or_path" : "guoyww/animatediff-motion-adapter-v1-5-3" },
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
504
508
):
505
509
model_type = "stable_cascade_stage_b"
506
510
507
- elif CHECKPOINT_KEY_NAMES ["sd3" ] in checkpoint :
511
+ elif CHECKPOINT_KEY_NAMES ["sd3" ] in checkpoint and checkpoint [ CHECKPOINT_KEY_NAMES [ "sd3" ]]. shape [ - 1 ] == 9216 :
508
512
model_type = "sd3"
509
513
514
+ elif CHECKPOINT_KEY_NAMES ["sd35_large" ] in checkpoint :
515
+ model_type = "sd35_large"
516
+
510
517
elif CHECKPOINT_KEY_NAMES ["animatediff" ] in checkpoint :
511
518
if CHECKPOINT_KEY_NAMES ["animatediff_scribble" ] in checkpoint :
512
519
model_type = "animatediff_scribble"
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
1670
1677
return new_weight
1671
1678
1672
1679
1680
def get_attn2_layers(state_dict):
    """Return the joint-block indices that carry a second (dual) attention.

    Keys look like ``joint_blocks.<N>.x_block.attn2....`` — assumed layout,
    matching the SD3.5 checkpoint naming; the integer after the first dot is
    the block index.

    Args:
        state_dict: checkpoint mapping of parameter names to tensors.

    Returns:
        Sorted tuple of unique block indices that contain an ``attn2.`` key;
        empty tuple when the checkpoint has no dual-attention layers.
    """
    # Set comprehension de-duplicates in one pass; sorted tuple keeps the
    # original deterministic ordering contract.
    return tuple(sorted({int(key.split(".")[1]) for key in state_dict if "attn2." in key}))
1689
+
1690
+
1691
def get_caption_projection_dim(state_dict):
    """Infer the caption projection width from the context embedder.

    The output-feature count (first dim of ``context_embedder.weight``) is the
    caption projection dimension of the transformer.

    Args:
        state_dict: checkpoint mapping of parameter names to tensors.

    Returns:
        The caption projection dimension as an int.
    """
    return state_dict["context_embedder.weight"].shape[0]
1694
+
1695
+
1673
1696
def convert_sd3_transformer_checkpoint_to_diffusers (checkpoint , ** kwargs ):
1674
1697
converted_state_dict = {}
1675
1698
keys = list (checkpoint .keys ())
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
1678
1701
checkpoint [k .replace ("model.diffusion_model." , "" )] = checkpoint .pop (k )
1679
1702
1680
1703
num_layers = list (set (int (k .split ("." , 2 )[1 ]) for k in checkpoint if "joint_blocks" in k ))[- 1 ] + 1 # noqa: C401
1681
- caption_projection_dim = 1536
1704
+ dual_attention_layers = get_attn2_layers (checkpoint )
1705
+
1706
+ caption_projection_dim = get_caption_projection_dim (checkpoint )
1707
+ has_qk_norm = any ("ln_q" in key for key in checkpoint .keys ())
1682
1708
1683
1709
# Positional and patch embeddings.
1684
1710
converted_state_dict ["pos_embed.pos_embed" ] = checkpoint .pop ("pos_embed" )
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
1735
1761
converted_state_dict [f"transformer_blocks.{ i } .attn.add_v_proj.weight" ] = torch .cat ([context_v ])
1736
1762
converted_state_dict [f"transformer_blocks.{ i } .attn.add_v_proj.bias" ] = torch .cat ([context_v_bias ])
1737
1763
1764
+ # qk norm
1765
+ if has_qk_norm :
1766
+ converted_state_dict [f"transformer_blocks.{ i } .attn.norm_q.weight" ] = checkpoint .pop (
1767
+ f"joint_blocks.{ i } .x_block.attn.ln_q.weight"
1768
+ )
1769
+ converted_state_dict [f"transformer_blocks.{ i } .attn.norm_k.weight" ] = checkpoint .pop (
1770
+ f"joint_blocks.{ i } .x_block.attn.ln_k.weight"
1771
+ )
1772
+ converted_state_dict [f"transformer_blocks.{ i } .attn.norm_added_q.weight" ] = checkpoint .pop (
1773
+ f"joint_blocks.{ i } .context_block.attn.ln_q.weight"
1774
+ )
1775
+ converted_state_dict [f"transformer_blocks.{ i } .attn.norm_added_k.weight" ] = checkpoint .pop (
1776
+ f"joint_blocks.{ i } .context_block.attn.ln_k.weight"
1777
+ )
1778
+
1738
1779
# output projections.
1739
1780
converted_state_dict [f"transformer_blocks.{ i } .attn.to_out.0.weight" ] = checkpoint .pop (
1740
1781
f"joint_blocks.{ i } .x_block.attn.proj.weight"
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
1750
1791
f"joint_blocks.{ i } .context_block.attn.proj.bias"
1751
1792
)
1752
1793
1794
+ if i in dual_attention_layers :
1795
+ # Q, K, V
1796
+ sample_q2 , sample_k2 , sample_v2 = torch .chunk (
1797
+ checkpoint .pop (f"joint_blocks.{ i } .x_block.attn2.qkv.weight" ), 3 , dim = 0
1798
+ )
1799
+ sample_q2_bias , sample_k2_bias , sample_v2_bias = torch .chunk (
1800
+ checkpoint .pop (f"joint_blocks.{ i } .x_block.attn2.qkv.bias" ), 3 , dim = 0
1801
+ )
1802
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_q.weight" ] = torch .cat ([sample_q2 ])
1803
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_q.bias" ] = torch .cat ([sample_q2_bias ])
1804
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_k.weight" ] = torch .cat ([sample_k2 ])
1805
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_k.bias" ] = torch .cat ([sample_k2_bias ])
1806
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_v.weight" ] = torch .cat ([sample_v2 ])
1807
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_v.bias" ] = torch .cat ([sample_v2_bias ])
1808
+
1809
+ # qk norm
1810
+ if has_qk_norm :
1811
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.norm_q.weight" ] = checkpoint .pop (
1812
+ f"joint_blocks.{ i } .x_block.attn2.ln_q.weight"
1813
+ )
1814
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.norm_k.weight" ] = checkpoint .pop (
1815
+ f"joint_blocks.{ i } .x_block.attn2.ln_k.weight"
1816
+ )
1817
+
1818
+ # output projections.
1819
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_out.0.weight" ] = checkpoint .pop (
1820
+ f"joint_blocks.{ i } .x_block.attn2.proj.weight"
1821
+ )
1822
+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_out.0.bias" ] = checkpoint .pop (
1823
+ f"joint_blocks.{ i } .x_block.attn2.proj.bias"
1824
+ )
1825
+
1753
1826
# norms.
1754
1827
converted_state_dict [f"transformer_blocks.{ i } .norm1.linear.weight" ] = checkpoint .pop (
1755
1828
f"joint_blocks.{ i } .x_block.adaLN_modulation.1.weight"
0 commit comments