@@ -4249,7 +4249,33 @@ def lora_state_dict(
 
         return state_dict
 
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    @classmethod
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        if transformer.config.image_dim is None:
+            return state_dict
+
+        if any(k.startswith("transformer.blocks.") for k in state_dict):
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+
+            if is_i2v_lora:
+                return state_dict
+
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
+                    )
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
+                    )
+
+        return state_dict
+
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
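Why zero-filling the extra k_img/v_img (add_k_proj/add_v_proj) entries is safe: a LoRA branch contributes lora_B @ lora_A on top of the frozen projection, so an all-zeros pair is a no-op and the expanded T2V LoRA leaves the I2V model's image cross-attention untouched while the shared layers still receive the trained deltas. A minimal sketch of that fact (illustrative rank and width only, not taken from the Wan config):

    import torch

    # Illustrative shapes; any rank/width behaves the same way.
    rank, dim = 4, 16
    lora_A = torch.zeros(rank, dim)  # zero-filled, as _maybe_expand_t2v_lora_for_i2v does
    lora_B = torch.zeros(dim, rank)

    x = torch.randn(2, dim)
    delta = lora_B @ lora_A          # all zeros, so the LoRA branch adds nothing
    assert torch.equal(x @ delta.T, torch.zeros(2, dim))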
@@ -4287,7 +4313,11 @@ def load_lora_weights(
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
+        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
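With this in place, a Wan T2V LoRA can be loaded directly into the I2V pipeline; the image-conditioning entries it lacks are zero-filled inside load_lora_weights. A rough usage sketch, assuming the diffusers WanImageToVideoPipeline and hypothetical checkpoint and LoRA repo ids (none of these ids come from the diff):

    import torch
    from diffusers import WanImageToVideoPipeline

    # Hypothetical repo ids, shown only to illustrate the call pattern.
    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
    )

    # A LoRA trained against the T2V transformer: the absent k_img/v_img
    # (add_k_proj/add_v_proj) keys are expanded with zeros while loading.
    pipe.load_lora_weights("some-user/wan-t2v-lora", adapter_name="t2v-style")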