[refactor] Wan single file implementation #11918

Open · wants to merge 38 commits into main from to-single-file/wan

Commits (38):
d7b9e42  update (a-r-r-o-w, Jul 14, 2025)
7e97e43  update (a-r-r-o-w, Jul 14, 2025)
aaf2df8  update (a-r-r-o-w, Jul 14, 2025)
ecabd2a  add coauthor (a-r-r-o-w, Jul 14, 2025)
ff21b7f  improve test (a-r-r-o-w, Jul 14, 2025)
da9bfb1  Merge branch 'to-single-file/flux' into to-single-file/wan (a-r-r-o-w, Jul 14, 2025)
b8f7fe6  handle ip adapter params correctly (a-r-r-o-w, Jul 14, 2025)
17b678f  Merge branch 'main' into to-single-file/flux (a-r-r-o-w, Jul 15, 2025)
0cda91d  fix chroma qkv fusion test (a-r-r-o-w, Jul 15, 2025)
bc64f12  fix fastercache implementation (a-r-r-o-w, Jul 15, 2025)
9acadf3  Merge branch 'to-single-file/flux' into to-single-file/wan (a-r-r-o-w, Jul 15, 2025)
c72e72c  remove set_attention_backend related code (a-r-r-o-w, Jul 15, 2025)
a0b276d  fix more tests (a-r-r-o-w, Jul 15, 2025)
1b77c56  Merge branch 'to-single-file/flux' into to-single-file/wan (a-r-r-o-w, Jul 15, 2025)
c141520  fight more tests (a-r-r-o-w, Jul 15, 2025)
4dcd672  add back set_attention_backend (a-r-r-o-w, Jul 15, 2025)
576da52  update (a-r-r-o-w, Jul 15, 2025)
e909b73  update (a-r-r-o-w, Jul 15, 2025)
1e7217f  make style (a-r-r-o-w, Jul 15, 2025)
4f52e34  make fix-copies (a-r-r-o-w, Jul 15, 2025)
d9c1683  make ip adapter processor compatible with attention dispatcher (a-r-r-o-w, Jul 15, 2025)
a73cb39  refactor chroma as well (a-r-r-o-w, Jul 15, 2025)
6c841e8  Merge branch 'to-single-file/flux' into to-single-file/wan (a-r-r-o-w, Jul 15, 2025)
4940b21  attnetion dispatcher support (a-r-r-o-w, Jul 15, 2025)
5b8927a  remove transpose; fix rope shape (a-r-r-o-w, Jul 15, 2025)
90cf967  remove rmsnorm assert (a-r-r-o-w, Jul 16, 2025)
91a7b97  minify and deprecate npu/xla processors (a-r-r-o-w, Jul 16, 2025)
1e6b1c5  remove rmsnorm assert (a-r-r-o-w, Jul 16, 2025)
251bb61  minify and deprecate npu/xla processors (a-r-r-o-w, Jul 16, 2025)
84d2c84  Merge branch 'main' into to-single-file/flux (a-r-r-o-w, Jul 16, 2025)
da994a7  Merge branch 'to-single-file/flux' into to-single-file/wan (a-r-r-o-w, Jul 16, 2025)
c857a8a  Merge branch 'main' into to-single-file/wan (a-r-r-o-w, Jul 17, 2025)
df63133  Merge branch 'main' into to-single-file/wan (a-r-r-o-w, Jul 18, 2025)
1465133  Merge branch 'main' into to-single-file/wan (a-r-r-o-w, Jul 21, 2025)
8468350  Merge branch 'main' into to-single-file/wan (a-r-r-o-w, Jul 23, 2025)
08152c5  Merge branch 'main' into to-single-file/wan (a-r-r-o-w, Jul 23, 2025)
6067995  update (a-r-r-o-w, Jul 23, 2025)
1dc9c65  Merge branch 'main' into to-single-file/wan (a-r-r-o-w, Jul 25, 2025)

251 changes: 197 additions & 54 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -21,10 +21,10 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -35,40 +35,67 @@

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
+class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )

     def __call__(
         self,
-        attn: Attention,
+        attn: "WanAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             # 512 is the context length of the text encoder, hardcoded for now
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states

-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))

         if rotary_emb is not None:
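
Note on the fused path in _get_qkv_projections: fusing works because concatenating the Q/K/V weight matrices along the output dimension yields one linear layer whose output chunks back into the three projections. A minimal standalone sketch of that identity (illustrative, not part of this diff):

import torch
import torch.nn as nn

dim = 64
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# One linear whose weight/bias are the three projections stacked along the output dim.
fused = nn.Linear(dim, 3 * dim, bias=True)
fused.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
fused.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

x = torch.randn(2, 8, dim)
q, k, v = fused(x).chunk(3, dim=-1)  # one GEMM instead of three
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)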

@@ -77,8 +104,7 @@ def apply_rotary_emb(
                 freqs_cos: torch.Tensor,
                 freqs_sin: torch.Tensor,
             ):
-                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
-                x1, x2 = x[..., 0], x[..., 1]
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                 cos = freqs_cos[..., 0::2]
                 sin = freqs_sin[..., 1::2]
                 out = torch.empty_like(hidden_states)
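
The rotary split above is a pure refactor: unflatten(-1, (-1, 2)).unbind(-1) separates the interleaved (even, odd) channel pairs exactly as the old view-and-index code did. Equivalence sketch (illustrative shapes):

import torch

x = torch.randn(2, 16, 4, 32)  # e.g. (batch, seq, heads, head_dim)

pairs = x.view(*x.shape[:-1], -1, 2)                   # old form
x1_old, x2_old = pairs[..., 0], pairs[..., 1]
x1_new, x2_new = x.unflatten(-1, (-1, 2)).unbind(-1)   # new form

assert torch.equal(x1_old, x1_new) and torch.equal(x2_old, x2_new)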
@@ -92,23 +118,34 @@ def apply_rotary_emb(
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)

-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)

-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)

         if hidden_states_img is not None:
@@ -119,6 +156,119 @@ def apply_rotary_emb(
         return hidden_states
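
The processor now hands q/k/v to dispatch_attention_fn in (batch, seq_len, heads, head_dim) layout and lets the dispatcher select the backend (via _attention_backend), instead of transposing for F.scaled_dot_product_attention itself. For reference, a standalone sketch of what the deleted lines did explicitly with plain SDPA (illustrative shapes, not diffusers API):

import torch
import torch.nn.functional as F

b, s, h, d = 2, 16, 4, 32
q = k = v = torch.randn(b, s, h * d)

# Split heads, move them to dim 1 as SDPA expects, then undo afterwards.
q_, k_, v_ = (t.unflatten(2, (h, -1)).transpose(1, 2) for t in (q, k, v))  # (b, h, s, d)
out = F.scaled_dot_product_attention(q_, k_, v_, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).flatten(2, 3)  # back to (b, s, h * d)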


+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
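
The shim above relies on __new__: constructing the deprecated class emits a warning and returns an instance of the replacement, and because the returned object is not an instance of the old class, Python skips the old __init__. Generic sketch of the pattern (names are illustrative):

import warnings

class NewProcessor:
    pass

class OldProcessor:
    def __new__(cls, *args, **kwargs):
        warnings.warn("OldProcessor is deprecated; use NewProcessor.", FutureWarning)
        return NewProcessor(*args, **kwargs)

proc = OldProcessor()                  # still constructible at old call sites
assert isinstance(proc, NewProcessor)  # but callers receive the new type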


+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
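
A detail worth noting in fuse_projections above: the fused linear is first created on the meta device (no memory is allocated), and load_state_dict(..., assign=True) then swaps the concatenated real tensors in place of the meta parameters, avoiding an extra allocation plus copy. Standalone sketch (assign=True needs PyTorch 2.1+):

import torch
import torch.nn as nn

weight = torch.randn(6, 3)
bias = torch.randn(6)

with torch.device("meta"):
    fused = nn.Linear(3, 6, bias=True)  # parameters exist only as shapes, on 'meta'

# assign=True replaces the meta parameters with these tensors instead of copying into them.
fused.load_state_dict({"weight": weight, "bias": bias}, strict=True, assign=True)
assert fused.weight.device.type == "cpu"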
@@ -244,8 +394,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

-        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)

         return freqs_cos, freqs_sin
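
The reshape change mirrors the removal of the head transpose in the attention processor: with query/key kept as (batch, seq_len, heads, head_dim), frequencies shaped (1, seq_len, 1, head_dim) broadcast over batch and heads, whereas the old (1, 1, seq_len, head_dim) matched the previous (batch, heads, seq_len, head_dim) layout. Broadcasting sketch (illustrative shapes):

import torch

q = torch.randn(2, 16, 4, 32)      # (batch, seq_len, heads, head_dim)
freqs = torch.randn(1, 16, 1, 32)  # new rope shape: broadcasts over batch and heads
assert (q * freqs).shape == q.shape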

@@ -266,33 +416,24 @@ def __init__(

         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
         )

         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -315,12 +456,12 @@ def forward(

         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output

         # 3. Feed-forward
@@ -333,7 +474,9 @@ def forward(
         return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.

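With AttentionMixin added to the model's bases, attention-module utilities become reachable from the model itself. A hedged usage sketch; the checkpoint id and the fuse_qkv_projections helper name are assumptions about the surrounding API, not confirmed by this diff:

import torch
from diffusers import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed checkpoint, for illustration
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
transformer.fuse_qkv_projections()  # assumed mixin helper that calls fuse_projections() on each WanAttention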