@@ -1596,48 +1596,131 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     converted_state_dict = {}
     original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
 
-    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+    lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
+    lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+
+    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
+    if diff_keys:
+        for diff_k in diff_keys:
+            param = original_state_dict[diff_k]
+            all_zero = torch.all(param == 0).item()
+            if all_zero:
+                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+                original_state_dict.pop(diff_k)
+
+    # For the `diff_b` keys, we treat them as lora_bias.
+    # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
 
     for i in range(num_blocks):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_A.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_B.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.self_attn.{o}.diff_b"
+                )
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.diff_b"
+                )
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
                 )
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
                 )
+                if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                    converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                        f"blocks.{i}.cross_attn.{o}.diff_b"
+                    )
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_A.weight"
+                f"blocks.{i}.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_B.weight"
+                f"blocks.{i}.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.{o}.diff_b"
+                )
+
+    # Remaining.
+    if original_state_dict:
+        if any("time_projection" in k for k in original_state_dict):
+            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_down_key}.weight"
+            )
+            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_up_key}.weight"
+            )
+            if "time_projection.1.diff_b" in original_state_dict:
+                converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
+                    "time_projection.1.diff_b"
+                )
+
+        if any("head.head" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+            if "head.head.diff_b" in original_state_dict:
+                converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+        for text_time in ["text_embedding", "time_embedding"]:
+            if any(text_time in k for k in original_state_dict):
+                for b_n in [0, 2]:
+                    diffusers_b_n = 1 if b_n == 0 else 2
+                    diffusers_name = (
+                        "condition_embedder.text_embedder"
+                        if text_time == "text_embedding"
+                        else "condition_embedder.time_embedder"
+                    )
+                    if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
+                        )
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
+                        )
+                        if f"{text_time}.{b_n}.diff_b" in original_state_dict:
+                            converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
+                                original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
+                            )
 
     if len(original_state_dict) > 0:
-        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+        diff = all(".diff" in k for k in original_state_dict)
+        if diff:
+            diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
+            if not all("lora" not in k for k in diff_keys):
+                raise ValueError
+            logger.info(
+                "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
+                "https://github.com/huggingface/diffusers//issues/new"
+            )
+        else:
+            raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
 
     for key in list(converted_state_dict.keys()):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
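A minimal, self-contained sketch of the two behaviors this patch adds, the naming-convention detection and the `diff_b` handling, run on a toy checkpoint. This is illustration only, not code from the PR; every key name and tensor shape below is invented:

```python
import torch

# Toy checkpoint in the original (non-diffusers) Wan naming, using the older
# "lora_down"/"lora_up" convention plus bias deltas; all keys/shapes are made up.
toy = {
    "diffusion_model.blocks.0.self_attn.q.lora_down.weight": torch.randn(4, 32),
    "diffusion_model.blocks.0.self_attn.q.lora_up.weight": torch.randn(32, 4),
    "diffusion_model.blocks.0.self_attn.o.diff_b": torch.zeros(32),          # all zeros -> pruned
    "diffusion_model.blocks.0.self_attn.q.diff_b": torch.full((32,), 0.01),  # kept -> lora_B.bias
}
sd = {k[len("diffusion_model."):]: v for k, v in toy.items()}

# Same detection as the patch: prefer peft-style lora_A/lora_B, else fall back to lora_down/lora_up.
lora_down_key = "lora_A" if any("lora_A" in k for k in sd) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in sd) else "lora_up"
print(lora_down_key, lora_up_key)  # -> lora_down lora_up

# Prune all-zero diff/diff_b tensors, then map the surviving diff_b onto lora_B.bias.
for k in [k for k in sd if k.endswith((".diff_b", ".diff"))]:
    if torch.all(sd[k] == 0):
        sd.pop(k)

converted = {}
i = 0
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
    down = sd.pop(f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight", None)
    if down is None:  # the toy dict only covers "q"
        continue
    converted[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down
    converted[f"blocks.{i}.attn1.{c}.lora_B.weight"] = sd.pop(f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight")
    if f"blocks.{i}.self_attn.{o}.diff_b" in sd:
        converted[f"blocks.{i}.attn1.{c}.lora_B.bias"] = sd.pop(f"blocks.{i}.self_attn.{o}.diff_b")

print(sorted(converted))
# ['blocks.0.attn1.to_q.lora_A.weight', 'blocks.0.attn1.to_q.lora_B.bias', 'blocks.0.attn1.to_q.lora_B.weight']
```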
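And a hedged end-to-end sketch of how this conversion path gets exercised once merged: the converter runs automatically inside the LoRA loader, so user code only calls `load_lora_weights`. It assumes a Wan pipeline that mixes in the LoRA loader; the model repo id and LoRA file name below are illustrative placeholders, not checkpoints referenced by this PR:

```python
import torch
from diffusers import WanPipeline

# Placeholder repo ids / file names, for illustration only.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some-user/some-wan-lora", weight_name="pytorch_lora_weights.safetensors")
```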