11 changes: 6 additions & 5 deletions nemo_automodel/components/models/gpt_oss/rope_utils.py
@@ -33,9 +33,10 @@ def apply_rotary_emb(
cos: Cosine tensor (..., rotary_dim // 2)
sin: Sine tensor (..., rotary_dim // 2)
"""
-cos = cos.unsqueeze(-2).to(x.dtype)
-sin = sin.unsqueeze(-2).to(x.dtype)
+cos = cos.unsqueeze(-2)
+sin = sin.unsqueeze(-2)
+dtype = x.dtype
+x = x.to(torch.float32)
# Handle partial rotary embeddings
# cos/sin have dimension rotary_dim//2, so full rotary_dim is cos.shape[-1] * 2
rotary_dim = cos.shape[-1] * 2
@@ -45,13 +45,13 @@ def apply_rotary_emb(
x1, x2 = torch.chunk(x_rot, 2, dim=-1)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
-return torch.cat((o1, o2, x_pass), dim=-1)
+return torch.cat((o1, o2, x_pass), dim=-1).to(dtype)
else:
# Standard full rotary embeddings
x1, x2 = torch.chunk(x, 2, dim=-1)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
-return torch.cat((o1, o2), dim=-1)
+return torch.cat((o1, o2), dim=-1).to(dtype)


class RotaryEmbedding(torch.nn.Module):
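The change above upcasts the rotation math to float32 and casts the result back to the input dtype, so low-precision activations keep their dtype while the sin/cos arithmetic runs in full precision. Below is a minimal usage sketch mirroring the dtype checks added in the tests later in this PR; the import path is assumed from the file header, and the tolerances follow the new bfloat16 test.

```python
import torch
from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb

# bf16 activations: the rotation is computed in float32 internally,
# but the output comes back in the input dtype.
x = torch.randn(2, 4, 8, 64, dtype=torch.bfloat16)  # (batch, seq, heads, head_dim)
cos = torch.randn(4, 32)                             # (seq, head_dim // 2), float32
sin = torch.randn(4, 32)

out = apply_rotary_emb(x, cos, sin)
assert out.dtype == torch.bfloat16

# The result should stay close to a full-precision reference.
ref = apply_rotary_emb(x.to(torch.float32), cos, sin)
torch.testing.assert_close(out.to(torch.float32), ref, rtol=1e-2, atol=1e-2)
```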
24 changes: 23 additions & 1 deletion nemo_automodel/components/moe/parallelizer.py
@@ -34,6 +34,7 @@
GroupedExpertsDeepEP,
MoE,
)
from nemo_automodel.shared.utils import dtype_from_str

logger = logging.getLogger(__name__)
_CP_STREAM = None
@@ -129,7 +130,11 @@ def apply_fsdp(
mp_policy: MixedPrecisionPolicy | None = None,
offload_policy: OffloadPolicy | None = None,
reshard_after_forward: bool = False,
lm_head_precision: str | torch.dtype | None = None,
):
if isinstance(lm_head_precision, str):
lm_head_precision = dtype_from_str(lm_head_precision, default=None)

if mp_policy is None:
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32, output_dtype=torch.bfloat16
@@ -175,7 +180,22 @@ def apply_fsdp(

lm_head = getattr(_model, "lm_head", None) or getattr(model, "lm_head", None)
if lm_head is not None:
-fully_shard_default(lm_head)
+# Use custom mixed precision policy for lm_head if lm_head_precision is specified
+if lm_head_precision == torch.float32:
Contributor: Is it possible to inspect the lm_head to figure out the precision?

Contributor Author: This option forces the lm_head to fp32 regardless of the checkpoint dtype; an fp32 lm_head helps with RL stability.

+    lm_head_mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.float32,
+        reduce_dtype=torch.float32,
+        output_dtype=torch.float32,
+    )
+    fully_shard(
+        lm_head,
+        mesh=fsdp_mesh,
+        reshard_after_forward=reshard_after_forward,
+        mp_policy=lm_head_mp_policy,
+        offload_policy=offload_policy,
+    )
+else:
+    fully_shard_default(lm_head)

fully_shard_default(_model)

@@ -214,6 +234,7 @@ def parallelize_model(
ep_shard_axis_names: tuple[str, ...] | None = None,
activation_checkpointing: bool = False,
reshard_after_forward: bool = False,
lm_head_precision: str | torch.dtype | None = None,
):
assert tp_axis_name is None or world_mesh[tp_axis_name].size() == 1, (
"Tensor parallelism not supported for custom MoE models"
@@ -251,4 +272,5 @@
ep_shard_enabled=ep_shard_mesh is not None and ep_shard_mesh.size() > 1,
ep_shard_mesh=ep_shard_mesh,
reshard_after_forward=reshard_after_forward,
lm_head_precision=lm_head_precision,
)
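As the review thread above notes, `lm_head_precision` exists to force the lm_head to fp32 (useful for RL stability), not to mirror the checkpoint dtype. The sketch below illustrates how a string-or-dtype argument resolves to an FSDP mixed-precision policy, as in the branch added to `apply_fsdp`. It is a standalone approximation, not code from the PR: `resolve_lm_head_policy` is a hypothetical helper, the string-to-dtype step stands in for `dtype_from_str`, and the `MixedPrecisionPolicy` import path assumes a recent PyTorch with FSDP2 exposed under `torch.distributed.fsdp`.

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy  # FSDP2; path may differ on older PyTorch

def resolve_lm_head_policy(
    lm_head_precision: str | torch.dtype | None,
    default_policy: MixedPrecisionPolicy,
) -> MixedPrecisionPolicy:
    """Hypothetical helper mirroring the branch added in apply_fsdp."""
    if isinstance(lm_head_precision, str):
        # Stand-in for dtype_from_str: map "float32" -> torch.float32, unknown -> None.
        lm_head_precision = getattr(torch, lm_head_precision, None)
    if lm_head_precision == torch.float32:
        # Only an explicit float32 request gets a dedicated all-fp32 policy.
        return MixedPrecisionPolicy(
            param_dtype=torch.float32,
            reduce_dtype=torch.float32,
            output_dtype=torch.float32,
        )
    # Otherwise keep the default bf16-param / fp32-reduce policy used for the rest of the model.
    return default_policy

default = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
assert resolve_lm_head_policy("float32", default).param_dtype == torch.float32
assert resolve_lm_head_policy(None, default) is default
```

In the actual diff, the resolved fp32 policy is passed to `fully_shard` for the lm_head module only, while every other module keeps the default policy.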
150 changes: 139 additions & 11 deletions tests/unit_tests/models/gpt_oss/test_rope_utils.py
@@ -152,7 +152,7 @@ def test_partial_rotary_preserves_passthrough(self):
x[..., rotary_dim:],
rtol=0,
atol=0,
msg="Pass-through dimensions should be exactly preserved"
msg="Pass-through dimensions should be exactly preserved",
)

def test_partial_rotary_different_factors(self):
@@ -174,6 +174,135 @@ def test_partial_rotary_different_factors(self):
# Verify pass-through is preserved
torch.testing.assert_close(result[..., rotary_dim:], x_pass)

def test_float32_computation_with_fp16_input(self):
"""Test that computation happens in float32 even with fp16 input"""
batch_size = 2
seq_len = 4
num_heads = 8
head_dim = 64

# Create fp16 input
x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
cos = torch.randn(seq_len, head_dim // 2)
sin = torch.randn(seq_len, head_dim // 2)

# Apply rotary embedding
result = apply_rotary_emb(x_fp16, cos, sin)

# Output should be fp16
assert result.dtype == torch.float16

# Compare with fp32 computation for numerical accuracy
x_fp32 = x_fp16.to(torch.float32)
result_fp32 = apply_rotary_emb(x_fp32, cos, sin)

# The fp16 result should be close to the fp32 result when cast to fp32
torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-3, atol=1e-3)

def test_float32_computation_with_bfloat16_input(self):
"""Test that computation happens in float32 even with bfloat16 input"""
batch_size = 2
seq_len = 4
num_heads = 8
head_dim = 64

# Create bfloat16 input
x_bf16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
cos = torch.randn(seq_len, head_dim // 2)
sin = torch.randn(seq_len, head_dim // 2)

# Apply rotary embedding
result = apply_rotary_emb(x_bf16, cos, sin)

# Output should be bfloat16
assert result.dtype == torch.bfloat16

# Compare with fp32 computation for numerical accuracy
x_fp32 = x_bf16.to(torch.float32)
result_fp32 = apply_rotary_emb(x_fp32, cos, sin)

# The bf16 result should be close to the fp32 result when cast to fp32
torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-2, atol=1e-2)

def test_cos_sin_dtype_independence(self):
"""Test that cos/sin dtype doesn't affect output dtype"""
batch_size = 2
seq_len = 4
num_heads = 8
head_dim = 64

x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)

# Test with different cos/sin dtypes
for cos_sin_dtype in [torch.float32, torch.float16, torch.bfloat16]:
cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)

result = apply_rotary_emb(x, cos, sin)

# Output dtype should match input x dtype, not cos/sin dtype
assert result.dtype == x.dtype

def test_partial_rotary_float32_computation_with_fp16(self):
"""Test that partial rotary also uses float32 computation with fp16 input"""
batch_size = 2
seq_len = 4
num_heads = 8
head_dim = 64
rotary_dim = 32 # Only rotate half the dimensions

# Create fp16 input
x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
cos = torch.randn(seq_len, rotary_dim // 2)
sin = torch.randn(seq_len, rotary_dim // 2)

# Store the pass-through part
x_pass_original = x_fp16[..., rotary_dim:].clone()

# Apply rotary embedding
result = apply_rotary_emb(x_fp16, cos, sin)

# Output should be fp16
assert result.dtype == torch.float16

# Pass-through dimensions should be exactly preserved (no dtype conversion artifacts)
torch.testing.assert_close(result[..., rotary_dim:], x_pass_original, rtol=0, atol=0)

# Compare with fp32 computation
x_fp32 = x_fp16.to(torch.float32)
result_fp32 = apply_rotary_emb(x_fp32, cos, sin)

# The rotated part should be close to fp32 computation
torch.testing.assert_close(
result[..., :rotary_dim].to(torch.float32), result_fp32[..., :rotary_dim], rtol=1e-3, atol=1e-3
)

def test_numerical_stability_with_mixed_dtypes(self):
"""Test numerical stability when x, cos, sin have different dtypes"""
batch_size = 2
seq_len = 4
num_heads = 8
head_dim = 64

# Test various dtype combinations
dtype_combinations = [
(torch.float16, torch.float32),
(torch.bfloat16, torch.float32),
(torch.float32, torch.float16),
]

for x_dtype, cos_sin_dtype in dtype_combinations:
x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=x_dtype)
cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)

# Should not raise any errors
result = apply_rotary_emb(x, cos, sin)

# Output should match input x dtype
assert result.dtype == x_dtype
assert result.shape == x.shape


class TestRotaryEmbedding:
"""Tests for RotaryEmbedding class"""
@@ -536,11 +665,13 @@ def test_different_batch_patterns(self):
dtype=torch.float32,
)

-position_ids = torch.tensor([
-    [0, 1, 2, 3], # Sequential
-    [0, 0, 1, 1], # Repeated
-    [10, 20, 30, 40], # Large gaps
-])
+position_ids = torch.tensor(
+    [
+        [0, 1, 2, 3], # Sequential
+        [0, 0, 1, 1], # Repeated
+        [10, 20, 30, 40], # Large gaps
+    ]
+)

freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd")

@@ -601,10 +732,7 @@ def test_freqs_cis_consistency_across_ranks(self, cp_size, cp_rank):
if len(indices) > 1:
# All tokens at this position should have identical freqs_cis
for i in range(1, len(indices)):
-torch.testing.assert_close(
-    freqs_cis_rank[indices[0]],
-    freqs_cis_rank[indices[i]]
-)
+torch.testing.assert_close(freqs_cis_rank[indices[0]], freqs_cis_rank[indices[i]])

def test_freqs_cis_cp_with_variable_sequence_lengths(self):
"""Test freqs_cis with variable-length sequences and CP splitting"""
@@ -697,7 +825,7 @@ def test_full_rope_pipeline(self):
# Step 2: Extract cos and sin from freqs_cis
# freqs_cis contains concatenated cos and sin
cos = freqs_cis[..., :32] # First half is cos
sin = freqs_cis[..., 32:] # Second half is sin
sin = freqs_cis[..., 32:] # Second half is sin

# Step 3: Apply rotary embeddings
x = torch.randn(batch_size, seq_len, num_heads, 64)