Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/diffusers/models/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...hooks.context_parallel import EquipartitionSharder
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
Expand Down Expand Up @@ -660,6 +661,15 @@ def forward(
timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
)
if ts_seq_len is not None:
# Check if running under context parallel and split along seq_len dimension
if hasattr(self, '_parallel_config') and self._parallel_config is not None:
cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
if cp_config is not None and cp_config._world_size > 1:
timestep_proj = EquipartitionSharder.shard(
timestep_proj,
dim=1,
mesh=cp_config._flattened_mesh
)
# batch_size, seq_len, 6, inner_dim
timestep_proj = timestep_proj.unflatten(2, (6, -1))
else:
Expand All @@ -681,6 +691,15 @@ def forward(

# 5. Output norm, projection & unpatchify
if temb.ndim == 3:
# Check if running under context parallel and split along seq_len dimension
if hasattr(self, '_parallel_config') and self._parallel_config is not None:
cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
if cp_config is not None and cp_config._world_size > 1:
temb = EquipartitionSharder.shard(
temb,
dim=1,
mesh=cp_config._flattened_mesh
)
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
Expand Down