
Commit 5ffda53

feat: fp32 lm_head and fp32 apply_rope options for MoE
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent: d195771

File tree: 4 files changed (+291 / -17 lines)

nemo_automodel/components/models/gpt_oss/rope_utils.py

Lines changed: 6 additions & 5 deletions

@@ -33,9 +33,10 @@ def apply_rotary_emb(
         cos: Cosine tensor (..., rotary_dim // 2)
         sin: Sine tensor (..., rotary_dim // 2)
     """
-    cos = cos.unsqueeze(-2).to(x.dtype)
-    sin = sin.unsqueeze(-2).to(x.dtype)
-
+    cos = cos.unsqueeze(-2)
+    sin = sin.unsqueeze(-2)
+    dtype = x.dtype
+    x = x.to(torch.float32)
     # Handle partial rotary embeddings
     # cos/sin have dimension rotary_dim//2, so full rotary_dim is cos.shape[-1] * 2
     rotary_dim = cos.shape[-1] * 2
@@ -45,13 +46,13 @@ def apply_rotary_emb(
         x1, x2 = torch.chunk(x_rot, 2, dim=-1)
         o1 = x1 * cos - x2 * sin
         o2 = x2 * cos + x1 * sin
-        return torch.cat((o1, o2, x_pass), dim=-1)
+        return torch.cat((o1, o2, x_pass), dim=-1).to(dtype)
     else:
         # Standard full rotary embeddings
         x1, x2 = torch.chunk(x, 2, dim=-1)
         o1 = x1 * cos - x2 * sin
         o2 = x2 * cos + x1 * sin
-        return torch.cat((o1, o2), dim=-1)
+        return torch.cat((o1, o2), dim=-1).to(dtype)


 class RotaryEmbedding(torch.nn.Module):
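
The net effect for callers: apply_rotary_emb now upcasts the activations to fp32 for the rotation itself and casts the result back to the input dtype, so the cos/sin tables can stay in fp32. A minimal usage sketch mirroring the new unit tests (shapes and values are illustrative only):

import torch

from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb

# bf16 activations, fp32 cos/sin tables (layout follows the unit tests: batch, seq, heads, head_dim)
x = torch.randn(2, 4, 8, 64, dtype=torch.bfloat16)
cos = torch.randn(4, 32)  # (seq_len, rotary_dim // 2), kept in fp32
sin = torch.randn(4, 32)

out = apply_rotary_emb(x, cos, sin)
assert out.dtype == torch.bfloat16  # rotation runs in fp32, output is cast back to the input dtype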

nemo_automodel/components/moe/parallelizer.py

Lines changed: 23 additions & 1 deletion

@@ -34,6 +34,7 @@
     GroupedExpertsDeepEP,
     MoE,
 )
+from nemo_automodel.components.moe.utils import BackendConfig

 logger = logging.getLogger(__name__)
 _CP_STREAM = None
@@ -129,6 +130,8 @@ def apply_fsdp(
     mp_policy: MixedPrecisionPolicy | None = None,
     offload_policy: OffloadPolicy | None = None,
     reshard_after_forward: bool = False,
+    backend_config: BackendConfig | None = None,
+    lm_head_precision: torch.dtype | None = None,
 ):
     if mp_policy is None:
         mp_policy = MixedPrecisionPolicy(
@@ -175,7 +178,22 @@ def apply_fsdp(

     lm_head = getattr(_model, "lm_head", None) or getattr(model, "lm_head", None)
     if lm_head is not None:
-        fully_shard_default(lm_head)
+        # Use custom mixed precision policy for lm_head if lm_head_precision is specified
+        if lm_head_precision == torch.float32:
+            lm_head_mp_policy = MixedPrecisionPolicy(
+                param_dtype=torch.float32,
+                reduce_dtype=torch.float32,
+                output_dtype=torch.float32,
+            )
+            fully_shard(
+                lm_head,
+                mesh=fsdp_mesh,
+                reshard_after_forward=reshard_after_forward,
+                mp_policy=lm_head_mp_policy,
+                offload_policy=offload_policy,
+            )
+        else:
+            fully_shard_default(lm_head)

     fully_shard_default(_model)

@@ -214,6 +232,8 @@ def parallelize_model(
     ep_shard_axis_names: tuple[str, ...] | None = None,
     activation_checkpointing: bool = False,
     reshard_after_forward: bool = False,
+    backend_config: BackendConfig | None = None,
+    lm_head_precision: torch.dtype | None = None,
 ):
     assert tp_axis_name is None or world_mesh[tp_axis_name].size() == 1, (
         "Tensor parallelism not supported for custom MoE models"
@@ -251,4 +271,6 @@ def parallelize_model(
         ep_shard_enabled=ep_shard_mesh is not None and ep_shard_mesh.size() > 1,
         ep_shard_mesh=ep_shard_mesh,
         reshard_after_forward=reshard_after_forward,
+        backend_config=backend_config,
+        lm_head_precision=lm_head_precision,
     )
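
For reference, a minimal sketch of the sharding pattern this option enables, written against the FSDP2 fully_shard / MixedPrecisionPolicy API. The model attributes (layers, lm_head) and the pre-built fsdp_mesh are assumptions for illustration, not the exact parallelizer internals:

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_with_optional_fp32_lm_head(model, fsdp_mesh, lm_head_precision=None):
    # Transformer blocks keep the usual bf16-compute / fp32-reduce policy.
    default_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for block in model.layers:
        fully_shard(block, mesh=fsdp_mesh, mp_policy=default_policy)

    if lm_head_precision == torch.float32:
        # lm_head gets its own all-fp32 policy so the final logits (and their gradient
        # reductions) stay in full precision.
        fp32_policy = MixedPrecisionPolicy(
            param_dtype=torch.float32,
            reduce_dtype=torch.float32,
            output_dtype=torch.float32,
        )
        fully_shard(model.lm_head, mesh=fsdp_mesh, mp_policy=fp32_policy)
    else:
        fully_shard(model.lm_head, mesh=fsdp_mesh, mp_policy=default_policy)

    # Root module is sharded last, as in apply_fsdp above.
    fully_shard(model, mesh=fsdp_mesh, mp_policy=default_policy)
    return model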

tests/unit_tests/models/gpt_oss/test_rope_utils.py

Lines changed: 139 additions & 11 deletions

@@ -152,7 +152,7 @@ def test_partial_rotary_preserves_passthrough(self):
             x[..., rotary_dim:],
             rtol=0,
             atol=0,
-            msg="Pass-through dimensions should be exactly preserved"
+            msg="Pass-through dimensions should be exactly preserved",
         )

     def test_partial_rotary_different_factors(self):
@@ -174,6 +174,135 @@ def test_partial_rotary_different_factors(self):
         # Verify pass-through is preserved
         torch.testing.assert_close(result[..., rotary_dim:], x_pass)

+    def test_float32_computation_with_fp16_input(self):
+        """Test that computation happens in float32 even with fp16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Create fp16 input
+        x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+        cos = torch.randn(seq_len, head_dim // 2)
+        sin = torch.randn(seq_len, head_dim // 2)
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_fp16, cos, sin)
+
+        # Output should be fp16
+        assert result.dtype == torch.float16
+
+        # Compare with fp32 computation for numerical accuracy
+        x_fp32 = x_fp16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The fp16 result should be close to the fp32 result when cast to fp32
+        torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-3, atol=1e-3)
+
+    def test_float32_computation_with_bfloat16_input(self):
+        """Test that computation happens in float32 even with bfloat16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Create bfloat16 input
+        x_bf16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
+        cos = torch.randn(seq_len, head_dim // 2)
+        sin = torch.randn(seq_len, head_dim // 2)
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_bf16, cos, sin)
+
+        # Output should be bfloat16
+        assert result.dtype == torch.bfloat16
+
+        # Compare with fp32 computation for numerical accuracy
+        x_fp32 = x_bf16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The bf16 result should be close to the fp32 result when cast to fp32
+        torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-2, atol=1e-2)
+
+    def test_cos_sin_dtype_independence(self):
+        """Test that cos/sin dtype doesn't affect output dtype"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+
+        # Test with different cos/sin dtypes
+        for cos_sin_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+            sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+
+            result = apply_rotary_emb(x, cos, sin)
+
+            # Output dtype should match input x dtype, not cos/sin dtype
+            assert result.dtype == x.dtype
+
+    def test_partial_rotary_float32_computation_with_fp16(self):
+        """Test that partial rotary also uses float32 computation with fp16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+        rotary_dim = 32  # Only rotate half the dimensions
+
+        # Create fp16 input
+        x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+        cos = torch.randn(seq_len, rotary_dim // 2)
+        sin = torch.randn(seq_len, rotary_dim // 2)
+
+        # Store the pass-through part
+        x_pass_original = x_fp16[..., rotary_dim:].clone()
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_fp16, cos, sin)
+
+        # Output should be fp16
+        assert result.dtype == torch.float16
+
+        # Pass-through dimensions should be exactly preserved (no dtype conversion artifacts)
+        torch.testing.assert_close(result[..., rotary_dim:], x_pass_original, rtol=0, atol=0)
+
+        # Compare with fp32 computation
+        x_fp32 = x_fp16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The rotated part should be close to fp32 computation
+        torch.testing.assert_close(
+            result[..., :rotary_dim].to(torch.float32), result_fp32[..., :rotary_dim], rtol=1e-3, atol=1e-3
+        )
+
+    def test_numerical_stability_with_mixed_dtypes(self):
+        """Test numerical stability when x, cos, sin have different dtypes"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Test various dtype combinations
+        dtype_combinations = [
+            (torch.float16, torch.float32),
+            (torch.bfloat16, torch.float32),
+            (torch.float32, torch.float16),
+        ]
+
+        for x_dtype, cos_sin_dtype in dtype_combinations:
+            x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=x_dtype)
+            cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+            sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+
+            # Should not raise any errors
+            result = apply_rotary_emb(x, cos, sin)
+
+            # Output should match input x dtype
+            assert result.dtype == x_dtype
+            assert result.shape == x.shape
+

 class TestRotaryEmbedding:
     """Tests for RotaryEmbedding class"""
@@ -536,11 +665,13 @@ def test_different_batch_patterns(self):
             dtype=torch.float32,
         )

-        position_ids = torch.tensor([
-            [0, 1, 2, 3],  # Sequential
-            [0, 0, 1, 1],  # Repeated
-            [10, 20, 30, 40],  # Large gaps
-        ])
+        position_ids = torch.tensor(
+            [
+                [0, 1, 2, 3],  # Sequential
+                [0, 0, 1, 1],  # Repeated
+                [10, 20, 30, 40],  # Large gaps
+            ]
+        )

         freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd")

@@ -601,10 +732,7 @@ def test_freqs_cis_consistency_across_ranks(self, cp_size, cp_rank):
             if len(indices) > 1:
                 # All tokens at this position should have identical freqs_cis
                 for i in range(1, len(indices)):
-                    torch.testing.assert_close(
-                        freqs_cis_rank[indices[0]],
-                        freqs_cis_rank[indices[i]]
-                    )
+                    torch.testing.assert_close(freqs_cis_rank[indices[0]], freqs_cis_rank[indices[i]])

     def test_freqs_cis_cp_with_variable_sequence_lengths(self):
         """Test freqs_cis with variable-length sequences and CP splitting"""
@@ -697,7 +825,7 @@ def test_full_rope_pipeline(self):
         # Step 2: Extract cos and sin from freqs_cis
         # freqs_cis contains concatenated cos and sin
         cos = freqs_cis[..., :32]  # First half is cos
-        sin = freqs_cis[..., 32:] # Second half is sin
+        sin = freqs_cis[..., 32:]  # Second half is sin

         # Step 3: Apply rotary embeddings
         x = torch.randn(batch_size, seq_len, num_heads, 64)
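
A toy comparison, separate from the test suite, of rotating in bf16 versus rotating in fp32 and casting back (both measured against an fp32 reference) illustrates the motivation for the upcast; exact error magnitudes depend on the data:

import torch

torch.manual_seed(0)
x = torch.randn(1, 128, 8, 64)   # fp32 reference activations (batch, seq, heads, head_dim)
cos = torch.randn(128, 32)
sin = torch.randn(128, 32)

def rotate(x, cos, sin):
    # Standard full rotary rotation, same math as apply_rotary_emb's else-branch.
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

ref = rotate(x, cos, sin)                                           # fp32 math
old = rotate(x.bfloat16(), cos.bfloat16(), sin.bfloat16()).float()  # old path: math in bf16
new = rotate(x.bfloat16().float(), cos, sin).bfloat16().float()     # new path: math in fp32, cast back

print("bf16-math max abs error:", (old - ref).abs().max().item())
print("fp32-math max abs error:", (new - ref).abs().max().item())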
