[Wan LoRAs] make T2V LoRAs compatible with Wan I2V #11107
Changes from 13 commits
@@ -4249,7 +4249,32 @@ def lora_state_dict(

```python
        return state_dict

    @classmethod
    def _maybe_expand_t2v_lora_for_i2v(
        cls,
        transformer: torch.nn.Module,
        state_dict,
    ):
        if any(k.startswith("blocks.") for k in state_dict):
            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
            is_i2v_lora = any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)
            # An I2V LoRA already has the _img projections; nothing to expand.
            if is_i2v_lora:
                return state_dict
```
Review comment: We don't perform any extra operation if it's already an I2V LoRA?

Reply: This is for loading a T2V LoRA into the I2V model, so if it's already an I2V LoRA we return the state dict as-is.
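For intuition, here is a minimal sketch (toy, made-up key names shaped like the ones the method inspects) of how the key scan distinguishes the two LoRA flavors:

```python
# Toy key sets, illustrative only; real Wan LoRA state dicts have many more entries.
t2v_keys = {"blocks.0.attn2.k.lora_A.weight", "blocks.0.attn2.v.lora_A.weight"}
i2v_keys = t2v_keys | {"blocks.0.attn2.k_img.lora_A.weight", "blocks.0.attn2.v_img.lora_A.weight"}

def looks_like_i2v(state_dict_keys):
    # Same predicate as is_i2v_lora above: both image branches must be present.
    return any("k_img" in k for k in state_dict_keys) and any("v_img" in k for k in state_dict_keys)

assert not looks_like_i2v(t2v_keys)
assert looks_like_i2v(i2v_keys)
```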
```python
            if transformer.config.image_dim is None:
                return state_dict
```
Review comment: This could be moved out to the top of this function.

Reply: Might be slightly faster than checking the keys first; this is checking whether the transformer is I2V. The T2V transformer config has `image_dim=None`.

Review comment: This should be checked first either way. T2V has `image_dim=None`: we have T2V loaded -> we are loading a T2V LoRA -> nothing to expand, return as-is; we have I2V loaded -> we are loading a T2V LoRA -> expand it with zero-filled `_img` entries. Either way the `image_dim` check can run before the key scan.

Reply: you're absolutely right 🙌🏻
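To summarize the four cases this thread walks through, a small sketch (hypothetical helper; the `1280` value is only illustrative of an I2V config — all that matters is `None` vs. not-`None`):

```python
def needs_expansion(model_image_dim, lora_has_img_layers):
    # Mirrors the two early returns in _maybe_expand_t2v_lora_for_i2v:
    # a T2V transformer (image_dim=None) never needs expansion, and an
    # I2V LoRA already carries its _img layers.
    if model_image_dim is None:
        return False
    if lora_has_img_layers:
        return False
    return True

assert needs_expansion(None, False) is False   # T2V model + T2V LoRA: load as-is
assert needs_expansion(None, True) is False    # T2V model + I2V LoRA: no expansion here
assert needs_expansion(1280, True) is False    # I2V model + I2V LoRA: load as-is
assert needs_expansion(1280, False) is True    # I2V model + T2V LoRA: expand with zeros
```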
```python
            for i in range(num_blocks):
                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                    # Zero-filled lora_A/lora_B factors: the added _img entries
                    # contribute an exactly-zero delta to the image projections.
                    state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
                        state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_A.weight"]
                    )
                    state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
                        state_dict[f"blocks.{i}.attn2.{o.replace('_img', '')}.lora_B.weight"]
                    )

        return state_dict
```
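For context, the end-to-end effect on the user side, as a hedged sketch — the checkpoint and LoRA repo IDs below are placeholders, not taken from this PR:

```python
import torch
from diffusers import WanImageToVideoPipeline

# Placeholder model and LoRA IDs for illustration.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)
# A LoRA trained against the Wan T2V transformer: with this PR,
# load_lora_weights() zero-fills the missing add_k_proj/add_v_proj
# entries so the I2V transformer accepts it.
pipe.load_lora_weights("some-user/wan-t2v-style-lora", adapter_name="style")
```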
```python
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
```
@@ -4287,7 +4312,11 @@ def load_lora_weights(

```python
        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        # Convert a T2V LoRA into an I2V LoRA (when loading into Wan I2V) by adding
        # zeros for the additional, otherwise missing, _img layers.
        state_dict = self._maybe_expand_t2v_lora_for_i2v(
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            state_dict=state_dict,
        )
        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")
```
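Why zero-filling is safe: a LoRA update contributes `lora_B @ lora_A` to the base weight, so all-zero factors make the added `_img` entries an exact no-op on the image branch. A tiny self-contained check:

```python
import torch

rank, dim_out, dim_in = 4, 8, 8  # stand-in shapes; real Wan blocks are larger
lora_A = torch.zeros(rank, dim_in)
lora_B = torch.zeros(dim_out, rank)
delta_w = lora_B @ lora_A
assert torch.all(delta_w == 0)  # the expanded _img projections stay untouched
```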