Commit 3a31b29

Use float32 RoPE freqs in Wan with MPS backends (#11643)
Use float32 for RoPE on MPS in Wan. The PyTorch MPS backend does not support float64 tensors, so the float64 rotary-embedding math fails on Apple Silicon; this change falls back to float32 there and keeps float64 on other backends.
1 parent: b975bce
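
Background for the change, as a minimal check one can run on an Apple Silicon machine. This snippet is illustrative and not from the commit; the exact exception type and message depend on the PyTorch version.

import torch

if torch.backends.mps.is_available():
    try:
        # float64 allocation on the MPS device is rejected by PyTorch.
        torch.ones(2, device="mps", dtype=torch.float64)
    except Exception as err:
        print(f"float64 on MPS rejected: {err}")
else:
    print("No MPS device on this host; CPU/CUDA keep the float64 path.")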

File tree: 1 file changed (+4 −2 lines)

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 4 additions & 2 deletions
@@ -72,7 +72,8 @@ def __call__(
         if rotary_emb is not None:
 
             def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
+                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
                 x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                 return x_out.type_as(hidden_states)
 
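
For context, the new per-call path as a self-contained sketch. The function body matches the diff; the tensor shapes and example values are illustrative assumptions, not taken from the file.

import torch

def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # MPS cannot represent float64, so rotate in float32 there; keep float64
    # elsewhere for precision. This mirrors the hunk above.
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    # Pair the head dimension into (real, imag) couples and rotate in the complex plane.
    x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
    x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
    # Return in the caller's dtype (e.g. bfloat16) for the rest of attention.
    return x_out.type_as(hidden_states)

# Illustrative call: (batch, seq, heads, head_dim) with an even head_dim.
x = torch.randn(1, 4, 2, 8, dtype=torch.bfloat16)
angles = torch.randn(4, 1, 4, dtype=torch.float64)  # broadcasts over heads
freqs = torch.polar(torch.ones_like(angles), angles)  # one complex rotation per pair
# (On MPS the table itself must also be complex64; see the __init__ hunk below.)
print(apply_rotary_emb(x, freqs).shape)  # torch.Size([1, 4, 2, 8])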

@@ -190,9 +191,10 @@ def __init__(
         t_dim = attention_head_dim - h_dim - w_dim
 
         freqs = []
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for dim in [t_dim, h_dim, w_dim]:
             freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
+                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
             )
             freqs.append(freq)
         self.freqs = torch.cat(freqs, dim=1)
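
And a standalone sketch of the frequency-table side, inlining a standard RoPE schedule instead of calling diffusers' get_1d_rotary_pos_embed (the helper's internals here are an assumption; only the freqs_dtype switch mirrors the diff).

import torch

def rope_freqs_1d(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Same switch as the diff: MPS builds the table in float32 because the
    # backend has no float64; everywhere else keeps float64 precision.
    freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
    # Classic RoPE schedule: inverse frequencies theta^(-2i/dim), i in [0, dim/2).
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype) / dim))
    angles = torch.outer(torch.arange(max_seq_len, dtype=freqs_dtype), inv_freq)
    # use_real=False in the diff means complex exponentials e^{i*angle}.
    return torch.polar(torch.ones_like(angles), angles)

# Illustrative: per-axis tables for time/height/width, concatenated like self.freqs.
t_dim, h_dim, w_dim, max_seq_len = 16, 8, 8, 32
freqs = torch.cat([rope_freqs_1d(d, max_seq_len) for d in (t_dim, h_dim, w_dim)], dim=1)
print(freqs.shape)  # torch.Size([32, 16]) -- max_seq_len x (t_dim + h_dim + w_dim) / 2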
