Commit 00f9273

sayakpaul and a-r-r-o-w authored

[WIP][LoRA] start supporting kijai wan lora. (#11579)

* start supporting kijai wan lora.
* diff_b keys.
* Apply suggestions from code review
* merge ready

Co-authored-by: Aryan <aryan@huggingface.co>

1 parent ceb7af2 commit 00f9273
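
In short, this commit teaches the Wan LoRA converter to accept kijai-style checkpoints, which ship lora_down/lora_up weight names (and optional diff_b bias deltas) instead of the peft-style lora_A/lora_B layout diffusers expects. A few illustrative renames, written out as a plain mapping (block 0 only, key names derived from the conversion logic below; the real converter covers every block plus the condition embedders and output head):

# Illustrative only: a handful of the renames applied by the updated converter.
KIJAI_TO_DIFFUSERS_EXAMPLES = {
    "diffusion_model.blocks.0.self_attn.q.lora_down.weight": "transformer.blocks.0.attn1.to_q.lora_A.weight",
    "diffusion_model.blocks.0.self_attn.q.lora_up.weight": "transformer.blocks.0.attn1.to_q.lora_B.weight",
    "diffusion_model.blocks.0.cross_attn.k.lora_down.weight": "transformer.blocks.0.attn2.to_k.lora_A.weight",
    "diffusion_model.blocks.0.ffn.0.lora_down.weight": "transformer.blocks.0.ffn.net.0.proj.lora_A.weight",
    # Non-zero `diff_b` tensors become peft-style LoRA biases (lora_bias):
    "diffusion_model.blocks.0.self_attn.o.diff_b": "transformer.blocks.0.attn1.to_out.0.lora_B.bias",
}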

File tree

1 file changed: +93 −10 lines changed

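
At the call site nothing changes: a kijai-format Wan LoRA goes through the usual load_lora_weights entry point, and the converter shown in the diff below rewrites its keys on the fly. A minimal usage sketch, assuming a Wan 2.1 text-to-video checkpoint in diffusers format; the model id and LoRA path are placeholders, not part of this commit:

import torch
from diffusers import WanPipeline

# Placeholder checkpoint and LoRA file; substitute the ones you actually use.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# State dicts whose keys start with "diffusion_model." are routed through the
# non-diffusers Wan converter before being loaded as a peft adapter.
pipe.load_lora_weights("path/to/kijai_wan_lora.safetensors", adapter_name="wan_lora")

video = pipe("a corgi running along a beach at sunset").frames[0]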

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 93 additions & 10 deletions
@@ -1596,48 +1596,131 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     converted_state_dict = {}
     original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
 
-    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+    lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
+    lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+
+    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
+    if diff_keys:
+        for diff_k in diff_keys:
+            param = original_state_dict[diff_k]
+            all_zero = torch.all(param == 0).item()
+            if all_zero:
+                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+                original_state_dict.pop(diff_k)
+
+    # For the `diff_b` keys, we treat them as lora_bias.
+    # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
 
     for i in range(num_blocks):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_A.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_B.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.self_attn.{o}.diff_b"
+                )
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
            )
+            if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.diff_b"
+                )
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
                 )
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
                 )
+                if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                    converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                        f"blocks.{i}.cross_attn.{o}.diff_b"
+                    )
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_A.weight"
+                f"blocks.{i}.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_B.weight"
+                f"blocks.{i}.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.{o}.diff_b"
+                )
+
+    # Remaining.
+    if original_state_dict:
+        if any("time_projection" in k for k in original_state_dict):
+            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_down_key}.weight"
+            )
+            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_up_key}.weight"
+            )
+            if "time_projection.1.diff_b" in original_state_dict:
+                converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
+                    "time_projection.1.diff_b"
+                )
+
+        if any("head.head" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+            if "head.head.diff_b" in original_state_dict:
+                converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+        for text_time in ["text_embedding", "time_embedding"]:
+            if any(text_time in k for k in original_state_dict):
+                for b_n in [0, 2]:
+                    diffusers_b_n = 1 if b_n == 0 else 2
+                    diffusers_name = (
+                        "condition_embedder.text_embedder"
+                        if text_time == "text_embedding"
+                        else "condition_embedder.time_embedder"
+                    )
+                    if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
+                        )
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
+                        )
+                        if f"{text_time}.{b_n}.diff_b" in original_state_dict:
+                            converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
+                                original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
+                            )
 
     if len(original_state_dict) > 0:
-        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+        diff = all(".diff" in k for k in original_state_dict)
+        if diff:
+            diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
+            if not all("lora" not in k for k in diff_keys):
+                raise ValueError
+            logger.info(
+                "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
+                "https://github.com/huggingface/diffusers//issues/new"
+            )
+        else:
+            raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
 
     for key in list(converted_state_dict.keys()):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
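
For a quick end-to-end check of the new path, the sketch below builds a toy single-block state dict in the lora_down/lora_up naming, adds one non-zero diff_b tensor, and runs it through the converter. The helper is private, so its import path and signature may change, and the rank/feature sizes here are arbitrary:

import torch

from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers

rank, dim = 4, 16
state_dict = {}
# Block 0 projections in kijai-style naming: self-attention, cross-attention, and the two FFN linears.
projections = [
    "self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
    "cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o",
    "ffn.0", "ffn.2",
]
for proj in projections:
    state_dict[f"diffusion_model.blocks.0.{proj}.lora_down.weight"] = torch.randn(rank, dim)
    state_dict[f"diffusion_model.blocks.0.{proj}.lora_up.weight"] = torch.randn(dim, rank)
# A non-zero bias delta; all-zero diff/diff_b tensors are dropped up front by the converter.
state_dict["diffusion_model.blocks.0.self_attn.o.diff_b"] = torch.ones(dim)

converted = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
print("transformer.blocks.0.attn1.to_q.lora_A.weight" in converted)    # True
print("transformer.blocks.0.attn1.to_out.0.lora_B.bias" in converted)  # True: diff_b -> lora_B.bias

The second print confirms the diff_b handling: the bias delta surfaces as a lora_B.bias entry, which peft then picks up through its lora_bias option, as noted in the comment added by this diff.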
