-
Notifications
You must be signed in to change notification settings - Fork 281
[Bugfix]: Correct handling of cos_sin_cache length #1900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
@whx-sjtu PTAL |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes performance issues in rotary embedding cos/sin cache handling by correcting variable usage and preventing unnecessary cache recreation. The fix ensures that the cache, which is already initialized with maximum context length, is not unnecessarily recreated during processing.
- Replaces cache recreation logic with an error when max_seq_len exceeds the initialized maximum
- Corrects variable assignment in
_set_cos_sin_cache
frommax_seq_len_cached
tomax_seq_len
- Removes redundant
max_seq_len
assignment during initialization since the cache setup handles this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reviewed codes of branch v0.9.1-dev and found that this problem has already been solved in that branch while hasn't been ported to main. Thanks for finding and fixing this. LGTM.
…alid inputs Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Codecov ReportAttention: Patch coverage is
❌ Your patch check has failed because the patch coverage (75.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1900 +/- ##
==========================================
+ Coverage 60.17% 60.21% +0.04%
==========================================
Files 71 71
Lines 7989 7995 +6
==========================================
+ Hits 4807 4814 +7
+ Misses 3182 3181 -1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
@whx-sjtu Which PR fixed this issue in the 0.9.1-dev branch? |
@@ -209,7 +211,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): | |||
|
|||
|
|||
def _set_cos_sin_cache(self, seq_len, device, dtype): | |||
self.max_seq_len_cached = seq_len | |||
self.max_seq_len = seq_len * self.scaling_factor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no problem in v0.9.1. what happens.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#1551 fixed this problem in v0.9.1-dev
What this PR does / why we need it?
This PR addresses the performance issue related to cos/sin cache handling:
The cos/sin cache is already initialized with the maximum context length during initialization. However, due to
max_seq_len_cache
being stored asseq_len
, the condition check was incorrect, leading to unnecessary cache recreation.Since the cos/sin cache is already initialized with maximum context length, it should not trigger recreation during the process.
Fixed variable naming:
max_seq_len_cache
was never used and should bemax_seq_len
. This also is the correct variable to check against the maximum context length.Does this PR introduce any user-facing change?
No
How was this patch tested?
CI pass.