Misalignment in Rotary Frequencies #12538

@Warvito

Describe the bug

In WanRotaryPosEmbed (link), attention_head_dim is split across the t, h, and w dimensions differently in __init__ and in forward. This causes a mismatch for some values of attention_head_dim. The same issue is also present in other models that use rotary embeddings (e.g., Skyreels_v2).

Details

In __init__, the dimensions are computed as

h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim

For attention_head_dim = 64, we get:

attention_head_dim = 64
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
print([t_dim, h_dim, w_dim])

which prints [24, 20, 20]

In forward, when splitting the dimensions, we have

split_sizes = [
    self.attention_head_dim - 2 * (self.attention_head_dim // 3),
    self.attention_head_dim // 3,
    self.attention_head_dim // 3,
]

so for the same attention_head_dim = 64, if we try

attention_head_dim = 64
split_sizes = [
    attention_head_dim - 2 * (attention_head_dim // 3),
    attention_head_dim // 3,
    attention_head_dim // 3,
]
print(split_sizes)

it prints [22, 21, 21]

In most models the attention head dimension is 128, and for that value the two splits match ([44, 42, 42] in both cases), but I was wondering whether this is a bug that should be fixed.
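To make it easier to see which head dimensions are affected, here is a minimal standalone sketch that reproduces both formulas from above and compares them. The helper names init_split, forward_split, and mismatched_head_dims are hypothetical and exist only for this illustration; they are not part of diffusers. Since 2 * (d // 6) and d // 3 differ exactly when d % 6 >= 3, the mismatch should show up for any even head dimension with d % 6 == 4.

```python
def init_split(attention_head_dim: int) -> list[int]:
    """Split as computed in __init__ (formula copied from the issue)."""
    h_dim = w_dim = 2 * (attention_head_dim // 6)
    t_dim = attention_head_dim - h_dim - w_dim
    return [t_dim, h_dim, w_dim]


def forward_split(attention_head_dim: int) -> list[int]:
    """Split as computed in forward (formula copied from the issue)."""
    third = attention_head_dim // 3
    return [attention_head_dim - 2 * third, third, third]


def mismatched_head_dims(candidates) -> list[int]:
    """Return the candidate head dims for which the two splits disagree."""
    return [d for d in candidates if init_split(d) != forward_split(d)]


if __name__ == "__main__":
    # Compare the two splits for the values discussed in this issue.
    for d in (64, 128):
        print(d, init_split(d), forward_split(d))

    # Scan a few multiples of 32 (an arbitrary choice for illustration).
    print(mismatched_head_dims(range(32, 257, 32)))
```

If this reasoning holds, computing the split once in __init__ and reusing the stored sizes in forward would avoid the two formulas drifting apart, though which of the two formulas is the intended one is a question for the maintainers.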

Reproduction

NA

Logs

System Info

NA

Who can help?

@DN6 @yiyixuxu @sayakpaul
