Conversation

@sywangyi (Contributor) commented Oct 30, 2025

Fix the crash when testing context parallelism (CP) for Wan2.2-TI2V-5B.
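For context: Ulysses-style context parallelism shards the token sequence across ranks, so every per-token tensor that feeds an elementwise op must be sharded the same way. A minimal sketch of the idea (illustrative only, not the actual diffusers internals; split_sequence is a hypothetical helper):

import torch

def split_sequence(x: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    # Keep only this rank's contiguous shard of the token sequence.
    return x.chunk(world_size, dim=dim)[rank]

# hidden_states (batch, seq_len, dim) becomes (batch, seq_len // world_size, dim) per rank.
# Wan2.2 TI2V uses per-token timestep modulation, so its scale/shift tensors must be
# sharded identically; otherwise norm(x) * (1 + scale) + shift mixes a sharded tensor
# with a full-length one and fails on dim 1, which is exactly the crash reported below.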

Test script:

import random

import numpy as np
import torch
from torch import distributed as dist

from diffusers import AutoencoderKLWan, ContextParallelConfig, WanPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video


model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device


def set_seed_for_all_ranks(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    generator = torch.Generator(device="cuda")
    generator.manual_seed(seed)
    return generator


device = setup_distributed()
generator = set_seed_for_all_ranks(42)
onload_device = device
offload_device = torch.device("cpu")

vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(
    model_id,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
ulysses_degree = torch.distributed.get_world_size()
pipe.transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=ulysses_degree))
# group offloading
apply_group_offloading(
    pipe.text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)
pipe.vae.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

pipe.vae.enable_tiling(
    tile_sample_min_height=480,
    tile_sample_min_width=960,
    tile_sample_stride_height=352,
    tile_sample_stride_width=640,
)
height = 704
width = 1280
num_frames = 121
num_inference_steps = 50
guidance_scale = 5.0


prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
# Standard Wan negative prompt (Chinese). Rough translation: garish colors, overexposed,
# static, blurry details, subtitles, style, artwork, painting, still frame, overall gray,
# worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers,
# poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused
# fingers, motionless frame, cluttered background, three legs, many people in the
# background, walking backwards.
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    generator=generator,
).frames[0]
if torch.distributed.get_rank() == 0:
    export_to_video(output, "5bit2v_output.mp4", fps=24)
if dist.is_initialized():
    torch.distributed.destroy_process_group()

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
@sywangyi (Contributor, Author)

torchrun --nproc-per-node 2 test.py
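Note: the script sets ulysses_degree = torch.distributed.get_world_size(), so --nproc-per-node determines how many ways the sequence is sharded; the command above shards it across two GPUs.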

Crash stack:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/disk3/wangyi/diffusers/test_14B_cp_offload.py", line 72, in <module>
[rank1]:     output = pipe(
[rank1]:              ^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/pipelines/wan/pipeline_wan.py", line 593, in __call__
[rank1]:     noise_pred = current_model(
[rank1]:                  ^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 680, in forward
[rank1]:     hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   [Previous line repeated 1 more time]
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 482, in forward
[rank1]:     norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
[rank1]:                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~
[rank1]: RuntimeError: The size of tensor a (13640) must match the size of tensor b (27280) at non-singleton dimension 1
[rank1]:[W1030 12:11:14.705123356 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resource, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
(Rank 0 fails with the identical traceback.)
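For reference, the mismatched sizes are consistent with the token sequence being split across the two ranks while the per-token modulation tensor is not. A back-of-the-envelope check, assuming TI2V-5B's 4x16x16 VAE compression and the transformer's (1, 2, 2) patch size:

num_frames, height, width = 121, 704, 1280

latent_frames = (num_frames - 1) // 4 + 1   # 31 (temporal compression 4)
tokens_h = height // 16 // 2                # 22 (VAE 16x spatial, patch size 2)
tokens_w = width // 16 // 2                 # 40

seq_len = latent_frames * tokens_h * tokens_w
print(seq_len)       # 27280: "tensor b", the unsplit per-token modulation
print(seq_len // 2)  # 13640: "tensor a", this rank's shard with ulysses_degree=2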

@sywangyi (Contributor, Author)

@yiyixuxu @sayakpaul, please help review.

@sayakpaul (Member)

Could you also share an example output produced with the fix?

@sayakpaul (Member) left a review comment

Thanks! Could you please explain the changes and also provide an example output?

@sayakpaul requested a review from DN6 on October 30, 2025 at 05:50.
@sywangyi (Contributor, Author)

It seems I cannot attach the video here; it may be blocked.

(image attachment)
