Skip to content

Use real-valued instead of complex tensors in Wan2.1 RoPE #11649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 65 additions & 25 deletions src/diffusers/models/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,24 @@ def __call__(

if rotary_emb is not None:

def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
def apply_rotary_emb(
hidden_states: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
dtype = (
torch.float32
if hidden_states.device.type == "mps"
else torch.float64
)
x = hidden_states.view(*hidden_states.shape[:-1], -1, 2).to(dtype)
x1, x2 = x[..., 0], x[..., 1]
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
out = torch.empty_like(hidden_states)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out

query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
Expand Down Expand Up @@ -179,7 +192,11 @@ def forward(

class WanRotaryPosEmbed(nn.Module):
def __init__(
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()

Expand All @@ -189,36 +206,59 @@ def __init__(

h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim

freqs = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

freqs_cos = []
freqs_sin = []

for dim in [t_dim, h_dim, w_dim]:
freq = get_1d_rotary_pos_embed(
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
freq_cos, freq_sin = get_1d_rotary_pos_embed(
dim,
max_seq_len,
theta,
use_real=True,
repeat_interleave_real=True,
freqs_dtype=freqs_dtype,
)
freqs.append(freq)
self.freqs = torch.cat(freqs, dim=1)
freqs_cos.append(freq_cos)
freqs_sin.append(freq_sin)

self.freqs_cos = torch.cat(freqs_cos, dim=1)
self.freqs_sin = torch.cat(freqs_sin, dim=1)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

freqs = self.freqs.to(hidden_states.device)
freqs = freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,
self.attention_head_dim // 6,
],
dim=1,
self.freqs_cos = self.freqs_cos.to(hidden_states.device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think doing it this way will cause a recompilation. We could probably just store as non-persistent buffer though with this refactor. The reason for not using buffer before was because it was a complex tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good thinking! I have added the proposed changes now.

self.freqs_sin = self.freqs_sin.to(hidden_states.device)

split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]

freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(
1, 1, ppf * pph * ppw, -1
)
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(
1, 1, ppf * pph * ppw, -1
)

freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs
return freqs_cos, freqs_sin


class WanTransformerBlock(nn.Module):
Expand Down