Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers #12594
+11
−10
Fixes #12538
The original implementation used inconsistent formulas in `__init__` and `forward`:
In `__init__`: dimensions were calculated as h_dim = w_dim = 2 * (attention_head_dim // 6)
In `forward`: split sizes were calculated as attention_head_dim // 3
The two expressions agree only when attention_head_dim % 6 < 3 (for example 128, where both give 42). Otherwise integer division makes them differ by one (for example 64, where they give 20 and 21), leading to potential tensor dimension mismatches when splitting the rotary embedding buffers.
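The discrepancy can be checked with plain integer arithmetic. The sketch below is illustrative only: the helper names `init_dims` and `forward_split_sizes` are hypothetical, and it assumes the temporal dimension is whatever remains after the height and width dimensions are subtracted, which this description does not state explicitly.

```python
def init_dims(attention_head_dim):
    # Formula quoted for __init__: height and width each get 2 * (d // 6).
    h_dim = w_dim = 2 * (attention_head_dim // 6)
    # Assumption: the temporal dimension takes the remainder.
    t_dim = attention_head_dim - h_dim - w_dim
    return t_dim, h_dim, w_dim


def forward_split_sizes(attention_head_dim):
    # Formula quoted for forward: height and width each get d // 3.
    hw = attention_head_dim // 3
    # Assumption: the temporal split again takes the remainder.
    return attention_head_dim - 2 * hw, hw, hw


for d in (64, 128):
    print(d, init_dims(d), forward_split_sizes(d))
# 64  -> (24, 20, 20) vs (22, 21, 21): the two formulas disagree
# 128 -> (44, 42, 42) vs (44, 42, 42): the two formulas agree
```

Under these assumptions the formulas only diverge for head dimensions where `attention_head_dim % 6 >= 3`, which may explain why the problem does not show up with every configuration.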
Solution
Store the computed dimensions (t_dim, h_dim, w_dim) as instance variables in `__init__` and reuse them in `forward()`, so both methods share a single source of truth for the split sizes.
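A minimal sketch of that pattern follows. It is not the code from this PR: `RotaryDims` is a hypothetical stand-in class, a plain Python list replaces the rotary embedding buffer, and the `t_dim` remainder formula is an assumption.

```python
class RotaryDims:
    """Hypothetical stand-in for a rotary embedding module."""

    def __init__(self, attention_head_dim):
        # Compute the per-axis dimensions once and keep them on the instance.
        self.h_dim = self.w_dim = 2 * (attention_head_dim // 6)
        # Assumption: the temporal axis takes whatever is left over.
        self.t_dim = attention_head_dim - self.h_dim - self.w_dim
        # Stand-in for the frequency buffer: one entry per head channel.
        self.freqs = list(range(attention_head_dim))

    def forward(self):
        # Reuse the stored dimensions instead of recomputing them
        # with a second, slightly different formula.
        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
        chunks, start = [], 0
        for size in split_sizes:
            chunks.append(self.freqs[start:start + size])
            start += size
        return chunks


rope = RotaryDims(attention_head_dim=64)
print([len(chunk) for chunk in rope.forward()])  # [24, 20, 20]
```

Because `forward` reads the same attributes that sized the buffer, the chunks always cover the buffer exactly, whatever the head dimension is.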
Who can review?
@Warvito, @yiyixuxu
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.