Skip to content

Commit a34d97c

Browse files
Authors: linoytsaban, github-actions[bot], Linoy, hlky
authored
[Wan LoRAs] make T2V LoRAs compatible with Wan I2V (#11107)
* @hlky t2v->i2v * Apply style fixes * try with ones to not nullify layers * fix method name * revert to zeros * add check to state_dict keys * add comment * copies fix * Revert "copies fix" This reverts commit 051f534. * remove copied from * Update src/diffusers/loaders/lora_pipeline.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/loaders/lora_pipeline.py Co-authored-by: hlky <hlky@hlky.ac> * update * update * Update src/diffusers/loaders/lora_pipeline.py Co-authored-by: hlky <hlky@hlky.ac> * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Linoy <linoy@hf.co> Co-authored-by: hlky <hlky@hlky.ac>
1 parent fc28791 commit a34d97c

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4249,7 +4249,33 @@ def lora_state_dict(
42494249

42504250
return state_dict
42514251

4252-
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4252+
@classmethod
4253+
def _maybe_expand_t2v_lora_for_i2v(
4254+
cls,
4255+
transformer: torch.nn.Module,
4256+
state_dict,
4257+
):
4258+
if transformer.config.image_dim is None:
4259+
return state_dict
4260+
4261+
if any(k.startswith("transformer.blocks.") for k in state_dict):
4262+
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
4263+
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
4264+
4265+
if is_i2v_lora:
4266+
return state_dict
4267+
4268+
for i in range(num_blocks):
4269+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
4270+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
4271+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
4272+
)
4273+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
4274+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
4275+
)
4276+
4277+
return state_dict
4278+
42534279
def load_lora_weights(
42544280
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
42554281
):
@@ -4287,7 +4313,11 @@ def load_lora_weights(
42874313

42884314
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
42894315
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4290-
4316+
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
4317+
state_dict = self._maybe_expand_t2v_lora_for_i2v(
4318+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4319+
state_dict=state_dict,
4320+
)
42914321
is_correct_format = all("lora" in key for key in state_dict.keys())
42924322
if not is_correct_format:
42934323
raise ValueError("Invalid LoRA checkpoint.")

0 commit comments

Comments (0)