feat: fp32 lm_head and fp32 apply_rope options for MoE #769
Conversation
if lm_head is not None:
    fully_shard_default(lm_head)
    # Use custom mixed precision policy for lm_head if lm_head_precision is specified
    if lm_head_precision == torch.float32:
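For context, a minimal sketch of how such an fp32 override could be wired with FSDP2's fully_shard and MixedPrecisionPolicy. This is an illustration under assumptions, not the PR's implementation; shard_lm_head and the lm_head_precision argument are hypothetical names mirroring the diff above, and an initialized torch.distributed process group is assumed.

    import torch
    # torch >= 2.6; earlier releases expose these under torch.distributed._composable.fsdp
    from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

    def shard_lm_head(lm_head: torch.nn.Module, lm_head_precision=None) -> None:
        """Hypothetical helper: shard the lm_head, optionally pinning it to fp32 compute."""
        if lm_head_precision == torch.float32:
            # Keep lm_head parameters, gradient reductions, and outputs in fp32,
            # regardless of the (typically bf16) policy used for the rest of the model.
            fp32_policy = MixedPrecisionPolicy(
                param_dtype=torch.float32,
                reduce_dtype=torch.float32,
                output_dtype=torch.float32,
            )
            fully_shard(lm_head, mp_policy=fp32_policy)
        else:
            # Fall back to the default mixed precision policy.
            fully_shard(lm_head)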
is it possible to inspect the lm_head to figure out the precision?
This option forces the lm_head to fp32 regardless of the checkpoint dtype; an fp32 lm_head helps with RL stability.
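On the RL-stability point, here is a small self-contained illustration (toy shapes, not the PR's code) of the log-prob mismatch a bf16 lm_head can introduce; this is the kind of drift that perturbs importance ratios in policy-gradient updates.

    import torch

    torch.manual_seed(0)
    hidden = torch.randn(8, 1024)            # stand-in hidden states
    weight = torch.randn(8192, 1024) * 0.02  # stand-in lm_head weight (vocab x hidden)

    logp_fp32 = torch.log_softmax(hidden @ weight.t(), dim=-1)
    logp_bf16 = torch.log_softmax((hidden.bfloat16() @ weight.bfloat16().t()).float(), dim=-1)

    # The per-token gap is small but nonzero, and it compounds over long sequences
    # and larger vocabularies/models.
    print((logp_fp32 - logp_bf16).abs().max())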
Signed-off-by: Hemil Desai <hemild@nvidia.com>
3579886 to 945039c
LGTM
/ok to test 2e79028
No description provided.