
Conversation

@hemildesai
Contributor

No description provided.

if lm_head is not None:
    fully_shard_default(lm_head)
    # Use custom mixed precision policy for lm_head if lm_head_precision is specified
    if lm_head_precision == torch.float32:
Contributor

Is it possible to inspect the lm_head to figure out the precision?

Contributor Author

This option is to force the lm_head to fp32 regardless of the checkpoint dtype. An fp32 lm_head helps with RL stability.
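For reference, here is a minimal sketch (not the PR's actual implementation) of how an fp32 lm_head can be enforced under FSDP2, assuming the fully_shard and MixedPrecisionPolicy API from torch.distributed.fsdp (older PyTorch releases expose them under torch.distributed._composable.fsdp) and an already-initialized process group / device mesh. The helper name shard_lm_head is hypothetical, and the PR's fully_shard_default wrapper is not reproduced here.

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    def shard_lm_head(lm_head: nn.Module, lm_head_precision=None) -> None:
        # Hypothetical helper: shard the lm_head, optionally pinning it to fp32.
        if lm_head_precision == torch.float32:
            # Keep parameters, gradient reduction, and logits in fp32 even if the
            # checkpoint and the rest of the model run in bf16.
            fp32_policy = MixedPrecisionPolicy(
                param_dtype=torch.float32,
                reduce_dtype=torch.float32,
                output_dtype=torch.float32,
            )
            fully_shard(lm_head, mp_policy=fp32_policy)
        else:
            # Otherwise fall back to the default sharding settings.
            fully_shard(lm_head)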

Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai force-pushed the hemil/fp32-lmhead-rope branch from 3579886 to 945039c on November 5, 2025 at 21:32
@copy-pr-bot

copy-pr-bot bot commented Nov 5, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Collaborator

@adil-a left a comment

LGTM

@adil-a
Collaborator

adil-a commented Nov 6, 2025

/ok to test 2e79028

@adil-a enabled auto-merge (squash) November 6, 2025 17:01
@adil-a merged commit 8316227 into main Nov 6, 2025
51 checks passed
@adil-a deleted the hemil/fp32-lmhead-rope branch November 6, 2025 20:37

4 participants