
Commit 0021bfa

yiyixuxu and a-r-r-o-w authored
support Wan-FLF2V (#11353)
* update transformer

---------

Co-authored-by: Aryan <aryan@huggingface.co>
1 parent bbd0c16 commit 0021bfa

File tree: 5 files changed, +226 -10 lines changed


docs/source/en/api/pipelines/wan.md

Lines changed: 54 additions & 0 deletions
@@ -133,6 +133,60 @@ output = pipe(
 export_to_video(output, "wan-i2v.mp4", fps=16)
 ```
 
+### First and Last Frame Interpolation
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+
+
+model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanImageToVideoPipeline.from_pretrained(
+    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
+last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
+
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+    aspect_ratio = image.height / image.width
+    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+    image = image.resize((width, height))
+    return image, height, width
+
+def center_crop_resize(image, height, width):
+    # Calculate resize ratio to match first frame dimensions
+    resize_ratio = max(width / image.width, height / image.height)
+
+    # Resize the image
+    width = round(image.width * resize_ratio)
+    height = round(image.height * resize_ratio)
+    size = [width, height]
+    image = TF.center_crop(image, size)
+
+    return image, height, width
+
+first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
+if last_frame.size != first_frame.size:
+    last_frame, _, _ = center_crop_resize(last_frame, height, width)
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+
+output = pipe(
+    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+```
+
 ### Video to Video Generation
 
 ```python
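
A side note on the resizing helper in the doc example above: `aspect_ratio_resize` snaps both output dimensions to multiples of `mod_value = vae_scale_factor_spatial * patch_size[1]` while keeping the area near `max_area`. Below is a minimal standalone sketch of that arithmetic, assuming Wan's spatial VAE scale factor of 8 and spatial patch size of 2 (so `mod_value = 16`) and a hypothetical 1920x1080 input; it is an illustration, not part of the documented example.

```python
import numpy as np

# Standalone sketch of the rounding done by aspect_ratio_resize.
# Assumed values: vae_scale_factor_spatial = 8, patch_size[1] = 2, so mod_value = 16.
max_area = 720 * 1280
image_height, image_width = 1080, 1920      # hypothetical input frame
aspect_ratio = image_height / image_width   # 0.5625

mod_value = 8 * 2                           # vae_scale_factor_spatial * patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value

print(height, width)  # 720 1280
```

Both results are divisible by 16, so the VAE downsampling and transformer patchification line up without extra cropping.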

scripts/convert_wan_to_diffusers.py

Lines changed: 41 additions & 1 deletion
@@ -39,6 +39,24 @@
     "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
     "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
     "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # for the FLF2V model
+    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
 }
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}

@@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+    elif model_type == "Wan-FLF2V-14B-720P":
+        config = {
+            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
+            "diffusers_config": {
+                "image_dim": 1280,
+                "added_kv_proj_dim": 5120,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 36,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "rope_max_seq_len": 1024,
+                "pos_embed_seq_len": 257 * 2,
+            },
+        }
     return config

@@ -397,7 +437,7 @@ def get_args():
         prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
     )
 
-    if "I2V" in args.model_type:
+    if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
         )
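
A note on how a rename table like the one extended above is typically consumed: the presence of diffusers-style keys such as `attn2.to_k_img` on the left-hand side suggests the entries are applied in order as substring replacements, so an original checkpoint key like `blocks.0.cross_attn.k_img.weight` would first be rewritten to use `attn2.to_k_img` and then to `attn2.add_k_proj`. The sketch below illustrates that assumed behavior with a hypothetical helper (`rename_key`) and a trimmed-down table; it is not code from the conversion script.

```python
# Hypothetical helper illustrating ordered substring-based key renaming.
# The table entries are a subset of the diff above; order matters for the *_img chain.
RENAME = {
    "cross_attn.k": "attn2.to_k",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
}

def rename_key(key: str, table: dict) -> str:
    for old, new in table.items():
        if old in key:
            key = key.replace(old, new)
    return key

print(rename_key("blocks.0.cross_attn.k.weight", RENAME))         # blocks.0.attn2.to_k.weight
print(rename_key("blocks.0.cross_attn.k_img.weight", RENAME))     # blocks.0.attn2.add_k_proj.weight
print(rename_key("blocks.0.cross_attn.norm_k_img.weight", RENAME))  # blocks.0.attn2.norm_added_k.weight
```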

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 18 additions & 4 deletions
@@ -49,8 +49,10 @@ def __call__(
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
-            encoder_hidden_states_img = encoder_hidden_states[:, :257]
-            encoder_hidden_states = encoder_hidden_states[:, 257:]
+            # 512 is the context length of the text encoder, hardcoded for now
+            image_context_length = encoder_hidden_states.shape[1] - 512
+            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
@@ -108,14 +110,23 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
 
 
 class WanImageEmbedding(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int):
+    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
 
         self.norm1 = FP32LayerNorm(in_features)
         self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
         self.norm2 = FP32LayerNorm(out_features)
+        if pos_embed_seq_len is not None:
+            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+        else:
+            self.pos_embed = None
 
     def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is not None:
+            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
         hidden_states = self.norm1(encoder_hidden_states_image)
         hidden_states = self.ff(hidden_states)
         hidden_states = self.norm2(hidden_states)

@@ -130,6 +141,7 @@ def __init__(
         time_proj_dim: int,
         text_embed_dim: int,
         image_embed_dim: Optional[int] = None,
+        pos_embed_seq_len: Optional[int] = None,
     ):
         super().__init__()
 
@@ -141,7 +153,7 @@ def __init__(
 
         self.image_embedder = None
         if image_embed_dim is not None:
-            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
 
     def forward(
         self,
@@ -350,6 +362,7 @@ def __init__(
         image_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
         rope_max_seq_len: int = 1024,
+        pos_embed_seq_len: Optional[int] = None,
     ) -> None:
         super().__init__()
 
@@ -368,6 +381,7 @@ def __init__(
             time_proj_dim=inner_dim * 6,
             text_embed_dim=text_dim,
             image_embed_dim=image_dim,
+            pos_embed_seq_len=pos_embed_seq_len,
         )
 
         # 3. Transformer blocks
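
For intuition about the new `pos_embed` path in `WanImageEmbedding`, the sketch below reproduces the reshape-and-add step with dummy tensors, using shapes taken from the FLF2V-14B-720P config in this commit (`image_dim=1280`, `pos_embed_seq_len=257 * 2`). The assumed input layout, with the two frames' CLIP token sequences stacked along the batch axis, is an illustration of the idea, not the module's own code.

```python
import torch

# Dummy shapes based on the FLF2V-14B-720P config in this diff:
# image_dim = 1280, pos_embed_seq_len = 257 * 2 (CLIP tokens for first + last frame).
batch_size, tokens_per_image, embed_dim = 2, 257, 1280

# Assumed layout: the two frames' CLIP hidden states arrive stacked on the batch axis,
# so the embedder sees (2 * batch_size, 257, 1280) here.
encoder_hidden_states_image = torch.randn(2 * batch_size, tokens_per_image, embed_dim)
pos_embed = torch.zeros(1, 2 * tokens_per_image, embed_dim)  # learned parameter in the real module

# The new forward folds the two per-frame sequences into one 514-token sequence
# per sample and adds the learned positional embedding.
b, seq_len, dim = encoder_hidden_states_image.shape
folded = encoder_hidden_states_image.view(-1, 2 * seq_len, dim) + pos_embed
print(folded.shape)  # torch.Size([2, 514, 1280])
```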

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 26 additions & 5 deletions
@@ -380,6 +380,7 @@ def prepare_latents(
         device: Optional[torch.device] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
+        last_image: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         latent_height = height // self.vae_scale_factor_spatial

@@ -398,9 +399,16 @@ def prepare_latents(
             latents = latents.to(device=device, dtype=dtype)
 
         image = image.unsqueeze(2)
-        video_condition = torch.cat(
-            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
-        )
+        if last_image is None:
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+            )
+        else:
+            last_image = last_image.unsqueeze(2)
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+                dim=2,
+            )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
@@ -424,7 +432,11 @@ def prepare_latents(
         latent_condition = (latent_condition - latents_mean) * latents_std
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
-        mask_lat_size[:, :, list(range(1, num_frames))] = 0
+
+        if last_image is None:
+            mask_lat_size[:, :, list(range(1, num_frames))] = 0
+        else:
+            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
         first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)

@@ -476,6 +488,7 @@ def __call__(
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         image_embeds: Optional[torch.Tensor] = None,
+        last_image: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,

@@ -620,7 +633,10 @@ def __call__(
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
         if image_embeds is None:
-            image_embeds = self.encode_image(image, device)
+            if last_image is None:
+                image_embeds = self.encode_image(image, device)
+            else:
+                image_embeds = self.encode_image([image, last_image], device)
         image_embeds = image_embeds.repeat(batch_size, 1, 1)
         image_embeds = image_embeds.to(transformer_dtype)

@@ -631,6 +647,10 @@ def __call__(
         # 5. Prepare latent variables
         num_channels_latents = self.vae.config.z_dim
         image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+        if last_image is not None:
+            last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+                device, dtype=torch.float32
+            )
         latents, condition = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,

@@ -642,6 +662,7 @@ def __call__(
             device,
             generator,
             latents,
+            last_image,
         )
 
         # 6. Denoising loop
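
To visualize what the `last_image` branch changes in `prepare_latents`, here is a small sketch with dummy tensors: the conditioning video places the first frame at index 0 and the last frame at index `num_frames - 1` with zeros in between, and only the in-between frames are masked out. The shapes are illustrative, not the pipeline's real latent sizes.

```python
import torch

# Dummy tensors standing in for the preprocessed first/last frames (sizes are arbitrary).
num_frames, channels, height, width = 9, 3, 16, 16
image = torch.ones(1, channels, height, width)
last_image = torch.ones(1, channels, height, width) * 2

# Mirrors the FLF2V branch above: first frame, zero padding, last frame along the frame axis.
image = image.unsqueeze(2)            # (1, C, 1, H, W)
last_image = last_image.unsqueeze(2)  # (1, C, 1, H, W)
video_condition = torch.cat(
    [image, image.new_zeros(1, channels, num_frames - 2, height, width), last_image], dim=2
)
print(video_condition.shape)  # torch.Size([1, 3, 9, 16, 16])

# Frame mask: only the in-between frames are zeroed, so both endpoints count as known conditioning.
mask = torch.ones(1, 1, num_frames, 1, 1)
mask[:, :, list(range(1, num_frames - 1))] = 0
print(mask[0, 0, :, 0, 0])  # tensor([1., 0., 0., 0., 0., 0., 0., 0., 1.])
```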

tests/pipelines/wan/test_wan_image_to_video.py

Lines changed: 87 additions & 0 deletions
@@ -160,3 +160,90 @@ def test_attention_slicing_forward_pass(self):
     @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
     def test_inference_batch_single_identical(self):
         pass
+
+
+class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        vae = AutoencoderKLWan(
+            base_dim=3,
+            z_dim=16,
+            dim_mult=[1, 1, 1, 1],
+            num_res_blocks=1,
+            temperal_downsample=[False, True, True],
+        )
+
+        torch.manual_seed(0)
+        # TODO: impl FlowDPMSolverMultistepScheduler
+        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        transformer = WanTransformer3DModel(
+            patch_size=(1, 2, 2),
+            num_attention_heads=2,
+            attention_head_dim=12,
+            in_channels=36,
+            out_channels=16,
+            text_dim=32,
+            freq_dim=256,
+            ffn_dim=32,
+            num_layers=2,
+            cross_attn_norm=True,
+            qk_norm="rms_norm_across_heads",
+            rope_max_seq_len=32,
+            image_dim=4,
+            pos_embed_seq_len=2 * (4 * 4 + 1),
+        )
+
+        torch.manual_seed(0)
+        image_encoder_config = CLIPVisionConfig(
+            hidden_size=4,
+            projection_dim=4,
+            num_hidden_layers=2,
+            num_attention_heads=2,
+            image_size=4,
+            intermediate_size=16,
+            patch_size=1,
+        )
+        image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+        torch.manual_seed(0)
+        image_processor = CLIPImageProcessor(crop_size=4, size=4)
+
+        components = {
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "image_encoder": image_encoder,
+            "image_processor": image_processor,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        image_height = 16
+        image_width = 16
+        image = Image.new("RGB", (image_width, image_height))
+        last_image = Image.new("RGB", (image_width, image_height))
+        inputs = {
+            "image": image,
+            "last_image": last_image,
+            "prompt": "dance monkey",
+            "negative_prompt": "negative",
+            "height": image_height,
+            "width": image_width,
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "num_frames": 9,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+        return inputs
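
One detail worth noting in the dummy config above: `pos_embed_seq_len=2 * (4 * 4 + 1)` matches the tiny CLIP vision tower (`image_size=4`, `patch_size=1`), assuming the usual ViT token count of one token per patch plus a class token, doubled for the first and last frame. A quick sanity check of that arithmetic:

```python
# Sanity check for the dummy pos_embed_seq_len above (assumes num_patches + 1 CLS token
# per image, two images for FLF2V).
image_size = 4
patch_size = 1
num_patches = (image_size // patch_size) ** 2  # 16
tokens_per_image = num_patches + 1             # 17
print(2 * tokens_per_image)                    # 34 == 2 * (4 * 4 + 1)
```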
