89 changes: 75 additions & 14 deletions src/diffusers/pipelines/wan/pipeline_wan_video2video.py
@@ -14,6 +14,7 @@

import html
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import regex as re
@@ -502,6 +503,8 @@ def __call__(
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
control_hidden_states: Optional[torch.Tensor] = None,
control_hidden_states_scale: Optional[torch.Tensor] = None,
):
r"""
The call function to the pipeline for generation.
@@ -559,6 +562,13 @@ def __call__(
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
control_hidden_states (`torch.Tensor`, *optional*):
Control tensor for the VACE control path, of shape `(B, C, T_patch, H_patch, W_patch)`. If omitted, a
neutral zero tensor with the correct size and dtype is created automatically. If the underlying
transformer does not accept these kwargs, this argument is ignored.
control_hidden_states_scale (`torch.Tensor`, *optional*):
1D tensor of per-layer scaling factors for the VACE layers (length equal to the number of VACE layers).
Defaults to a tensor of ones. Ignored if the underlying transformer does not accept these kwargs.

Examples:

@@ -593,6 +603,20 @@ def __call__(
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
base_tr = (
self.transformer.get_base_model() if hasattr(self.transformer, "get_base_model") else self.transformer
)
_sig = inspect.signature(base_tr.forward)
_supports_control = (
"control_hidden_states" in _sig.parameters and "control_hidden_states_scale" in _sig.parameters
)
# Warn if user passed control kwargs but model won't consume them
if not _supports_control and (control_hidden_states is not None or control_hidden_states_scale is not None):
warnings.warn(
"control_hidden_states/control_hidden_states_scale were provided, but the underlying transformer "
"does not accept these kwargs; they will be ignored.",
stacklevel=2,
)

device = self._execution_device

@@ -647,6 +671,30 @@ def __call__(
latent_timestep,
)

# Precompute shapes we’ll need
B = batch_size * num_videos_per_prompt

# Build neutral control tensors only if the base transformer supports them
if _supports_control:
cfg_tr = self.transformer.config # FrozenDict-like
C_ctrl = cfg_tr.get("vace_in_channels", cfg_tr.get("out_channels", cfg_tr.get("in_channels", 320)))
ps = cfg_tr.get("patch_size", (1, 1, 1))
if isinstance(ps, int):
pt = ph = pw = ps
else:
pt, ph, pw = (ps[0], ps[1], ps[2]) if len(ps) == 3 else (ps[0], ps[0], ps[0])

# If not provided, create a neutral single-token control tensor
if control_hidden_states is None:
control_hidden_states = torch.zeros(
(B, int(C_ctrl), int(pt), int(ph), int(pw)), device=device, dtype=transformer_dtype
)
# Layer-wise scale vector (not batched)
if control_hidden_states_scale is None:
vls = cfg_tr.get("vace_layers", [])
n_layers = len(vls) if isinstance(vls, (list, tuple)) else int(vls or 0)
control_hidden_states_scale = torch.ones(max(1, n_layers), device=device, dtype=transformer_dtype)

# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -660,22 +708,35 @@ def __call__(
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])

noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# Prepare kwargs for transformer call; keep identical for cond/uncond (swap only encoder_hidden_states)
call_kwargs = {
"hidden_states": latent_model_input,
"timestep": timestep,
"encoder_hidden_states": prompt_embeds,
"attention_kwargs": attention_kwargs,
"return_dict": False,
}

# If supported, attach control tensors; ensure batch/device/dtype match latent input
if _supports_control:
if control_hidden_states.shape[0] != latent_model_input.shape[0]:
control_hidden_states = control_hidden_states.expand(
latent_model_input.shape[0], -1, -1, -1, -1
)
call_kwargs["control_hidden_states"] = control_hidden_states.to(
device=latent_model_input.device, dtype=transformer_dtype
)
call_kwargs["control_hidden_states_scale"] = control_hidden_states_scale.to(
device=latent_model_input.device, dtype=transformer_dtype
)

# Cond pass
noise_pred = self.transformer(**call_kwargs)[0]

if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# Uncond pass: swap encoder_hidden_states; keep control kwargs identical
call_kwargs["encoder_hidden_states"] = negative_prompt_embeds
noise_uncond = self.transformer(**call_kwargs)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

# compute the previous noisy sample x_t -> x_t-1
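For context outside the diff, here is a minimal usage sketch of the new arguments. The checkpoint ID, frame count, resolution, control-tensor shape, and layer count below are my assumptions, not values taken from this PR; only the `control_hidden_states` / `control_hidden_states_scale` argument names and the ignore-with-warning behaviour come from the change above.

```python
# Illustrative only — checkpoint ID, frame count/resolution, control-tensor shape, and the
# number of VACE layers are placeholders, not values taken from this PR.
import PIL.Image
import torch

from diffusers import WanVideoToVideoPipeline

pipe = WanVideoToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # placeholder checkpoint; use the model you actually target
    torch_dtype=torch.float16,
).to("cuda")

# Dummy input clip: 17 blank frames at the pipeline's default resolution.
video = [PIL.Image.new("RGB", (832, 480)) for _ in range(17)]

# Assumed shapes: (batch, control channels, patched T, H, W) and one scale per VACE layer.
control = torch.zeros((1, 96, 1, 2, 2), dtype=torch.float16, device="cuda")
control_scale = torch.ones(8, dtype=torch.float16, device="cuda")

frames = pipe(
    video=video,
    prompt="a cat surfing a wave",
    strength=0.7,
    num_inference_steps=30,
    control_hidden_states=control,              # ignored (with a warning) if the transformer
    control_hidden_states_scale=control_scale,  # does not accept these kwargs
).frames[0]
```

With a plain Wan transformer that lacks the VACE path, this call simply emits the new warning and proceeds without control conditioning.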
106 changes: 106 additions & 0 deletions tests/pipelines/wan/test_wan_video_to_video.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import torch
@@ -50,6 +51,12 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
supports_dduf = False

def _supports_control_kwargs(self, transformer) -> bool:
"""Return True if the base transformer's forward() accepts VACE control kwargs."""
base = transformer.get_base_model() if hasattr(transformer, "get_base_model") else transformer
sig = inspect.signature(base.forward)
return "control_hidden_states" in sig.parameters and "control_hidden_states_scale" in sig.parameters

def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
@@ -147,3 +154,102 @@ def test_float16_inference(self):
)
def test_save_load_float16(self):
pass

def test_neutral_control_injection_no_crash_latent(self):
device = "cpu"

# Reuse the same tiny components
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(device)
pipe.set_progress_bar_config(disable=None)

# If transformer doesn't support control kwargs, this test isn't applicable.
if not self._supports_control_kwargs(pipe.transformer):
self.skipTest("Transformer doesn't accept VACE control kwargs; skipping control injection test.")

# --- Ensure VACE fields exist for control tensor sizing ---
# Prefer real module in_channels if present
pe = getattr(pipe.transformer, "vace_patch_embedding", None)
if pe is not None and hasattr(pe, "in_channels"):
vace_in = int(pe.in_channels)
else:
# fallback to model config fields
vace_in = int(getattr(pipe.transformer.config, "vace_in_channels", pipe.transformer.config.in_channels))
# also set it to help the pipeline code path
pipe.transformer.config.vace_in_channels = vace_in

# vace_layers: ensure non-empty so scale vector has length >=1
if not hasattr(pipe.transformer.config, "vace_layers"):
pipe.transformer.config.vace_layers = [0, 1]

# Run in latent mode so the test skips VAE decode and video preprocessing
# Build tiny latents matching transformer.config.in_channels
C = int(pipe.transformer.config.in_channels)
# Very small T/H/W to keep speed
latents = torch.zeros((1, C, 2, 8, 8), device=device, dtype=torch.float32)

out = pipe(
video=None,
prompt="test",
negative_prompt=None,
height=16,
width=16,
num_inference_steps=2,
guidance_scale=1.0, # disable CFG branch to keep path minimal
strength=0.5,
generator=None,
latents=latents, # <- latent path, so we don’t need real VAE/video_processor
prompt_embeds=None,
negative_prompt_embeds=None,
output_type="latent", # <- prevents decode/postprocess
return_dict=True,
max_sequence_length=16,
).frames

# Assert: no crash and the latent shape is preserved
self.assertIsInstance(out, torch.Tensor)
self.assertEqual(tuple(out.shape), tuple(latents.shape))

def test_neutral_control_injection_with_cfg(self):
device = "cpu"

components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(device)
pipe.set_progress_bar_config(disable=None)

if not self._supports_control_kwargs(pipe.transformer):
self.skipTest("Transformer doesn't accept VACE control kwargs; skipping control+CFG test.")

# Ensure VACE sizing hints exist (as above)
pe = getattr(pipe.transformer, "vace_patch_embedding", None)
if pe is not None and hasattr(pe, "in_channels"):
vace_in = int(pe.in_channels)
else:
vace_in = int(getattr(pipe.transformer.config, "vace_in_channels", pipe.transformer.config.in_channels))
pipe.transformer.config.vace_in_channels = vace_in
if not hasattr(pipe.transformer.config, "vace_layers"):
pipe.transformer.config.vace_layers = [0, 1, 2]

C = int(pipe.transformer.config.in_channels)
latents = torch.zeros((1, C, 2, 8, 8), device=device, dtype=torch.float32)

out = pipe(
video=None,
prompt="test",
negative_prompt="",
height=16,
width=16,
num_inference_steps=2,
guidance_scale=3.5, # trigger CFG (uncond) path
strength=0.5,
generator=None,
latents=latents,
prompt_embeds=None,
negative_prompt_embeds=None,
output_type="latent",
return_dict=True,
max_sequence_length=16,
).frames

self.assertIsInstance(out, torch.Tensor)
self.assertEqual(tuple(out.shape), tuple(latents.shape))
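
Both the pipeline change and the new tests gate the control path on the same forward-signature check. A standalone sketch of that pattern follows, using a hypothetical `ToyTransformer` stand-in rather than a diffusers model.

```python
# Standalone sketch of the capability check used above: inspect the base model's
# forward() signature and only attach the control kwargs when both names are accepted.
# `ToyTransformer` is a stand-in, not a diffusers class.
import inspect

import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    def forward(self, hidden_states, timestep, encoder_hidden_states,
                control_hidden_states=None, control_hidden_states_scale=None):
        return (hidden_states,)


def supports_control_kwargs(transformer) -> bool:
    # PEFT-style wrappers typically expose the wrapped module via get_base_model()
    base = transformer.get_base_model() if hasattr(transformer, "get_base_model") else transformer
    sig = inspect.signature(base.forward)
    return "control_hidden_states" in sig.parameters and "control_hidden_states_scale" in sig.parameters


model = ToyTransformer()
call_kwargs = {
    "hidden_states": torch.zeros(1, 4),
    "timestep": torch.zeros(1),
    "encoder_hidden_states": torch.zeros(1, 4),
}
if supports_control_kwargs(model):
    call_kwargs["control_hidden_states"] = torch.zeros(1, 4)
    call_kwargs["control_hidden_states_scale"] = torch.ones(1)
out = model(**call_kwargs)[0]
```

Checking the unwrapped base model rather than the wrapper keeps the detection correct when the transformer is wrapped for LoRA, which is why both the pipeline and the test helper call `get_base_model()` when it is available.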