Skip to content

Commit 50329d7

Browse files
committed
Add initial Wan Animate transformer tests
1 parent 8216aef commit 50329d7

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright 2025 HuggingFace Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import torch
18+
19+
from diffusers import WanAnimateTransformer3DModel
20+
21+
from ...testing_utils import (
22+
enable_full_determinism,
23+
torch_device,
24+
)
25+
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
26+
27+
28+
enable_full_determinism()
29+
30+
31+
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
32+
model_class = WanAnimateTransformer3DModel
33+
main_input_name = "hidden_states"
34+
uses_custom_attn_processor = True
35+
36+
@property
37+
def dummy_input(self):
38+
batch_size = 1
39+
num_channels = 4
40+
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
41+
height = 16
42+
width = 16
43+
text_encoder_embedding_dim = 16
44+
sequence_length = 12
45+
46+
clip_seq_len = 12
47+
clip_dim = 16
48+
49+
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
50+
face_height = 8
51+
face_width = 8
52+
53+
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
54+
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
55+
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
56+
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
57+
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
58+
face_pixel_values = torch.randn(
59+
(batch_size, 3, inference_segment_length, face_height, face_width)
60+
).to(torch_device)
61+
62+
return {
63+
"hidden_states": hidden_states,
64+
"timestep": timestep,
65+
"encoder_hidden_states": encoder_hidden_states,
66+
"encoder_hidden_states_image": clip_ref_features,
67+
"pose_hidden_states": pose_latents,
68+
"face_pixel_values": face_pixel_values,
69+
}
70+
71+
@property
72+
def input_shape(self):
73+
return (4, 1, 16, 16)
74+
75+
@property
76+
def output_shape(self):
77+
return (4, 1, 16, 16)
78+
79+
def prepare_init_args_and_inputs_for_common(self):
80+
init_dict = {
81+
"patch_size": (1, 2, 2),
82+
"num_attention_heads": 2,
83+
"attention_head_dim": 12,
84+
"in_channels": 12, # 2 * C + 4 = 2 * 4 + 4 = 12
85+
"latent_channels": 4,
86+
"out_channels": 4,
87+
"text_dim": 16,
88+
"freq_dim": 256,
89+
"ffn_dim": 32,
90+
"num_layers": 2,
91+
"cross_attn_norm": True,
92+
"qk_norm": "rms_norm_across_heads",
93+
"image_dim": 16,
94+
"rope_max_seq_len": 32,
95+
"motion_encoder_size": 8, # Start of Wan Animate-specific config
96+
"motion_style_dim": 8,
97+
"motion_dim": 4,
98+
"motion_encoder_dim": 16,
99+
"face_encoder_hidden_dim": 16,
100+
"face_encoder_num_heads": 2,
101+
"inject_face_latents_blocks": 2,
102+
}
103+
inputs_dict = self.dummy_input
104+
return init_dict, inputs_dict
105+
106+
def test_gradient_checkpointing_is_applied(self):
107+
expected_set = {"WanAnimateTransformer3DModel"}
108+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
109+
110+
# Override test_output because the transformer output is expected to have less channels than the main transformer
111+
# input.
112+
def test_output(self):
113+
expected_output_shape = (1, 4, 21, 16, 16)
114+
super().test_output(expected_output_shape=expected_output_shape)
115+
116+
117+
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
118+
model_class = WanAnimateTransformer3DModel
119+
120+
def prepare_init_args_and_inputs_for_common(self):
121+
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()

0 commit comments

Comments
 (0)