diff --git a/nemo_automodel/components/models/gpt_oss/rope_utils.py b/nemo_automodel/components/models/gpt_oss/rope_utils.py
index 5177f0220..42e6b4396 100644
--- a/nemo_automodel/components/models/gpt_oss/rope_utils.py
+++ b/nemo_automodel/components/models/gpt_oss/rope_utils.py
@@ -33,9 +33,10 @@ def apply_rotary_emb(
         cos: Cosine tensor (..., rotary_dim // 2)
         sin: Sine tensor (..., rotary_dim // 2)
     """
-    cos = cos.unsqueeze(-2).to(x.dtype)
-    sin = sin.unsqueeze(-2).to(x.dtype)
-
+    cos = cos.unsqueeze(-2)
+    sin = sin.unsqueeze(-2)
+    dtype = x.dtype
+    x = x.to(torch.float32)
     # Handle partial rotary embeddings
     # cos/sin have dimension rotary_dim//2, so full rotary_dim is cos.shape[-1] * 2
     rotary_dim = cos.shape[-1] * 2
@@ -45,13 +46,13 @@
         x1, x2 = torch.chunk(x_rot, 2, dim=-1)
         o1 = x1 * cos - x2 * sin
         o2 = x2 * cos + x1 * sin
-        return torch.cat((o1, o2, x_pass), dim=-1)
+        return torch.cat((o1, o2, x_pass), dim=-1).to(dtype)
     else:
         # Standard full rotary embeddings
         x1, x2 = torch.chunk(x, 2, dim=-1)
        o1 = x1 * cos - x2 * sin
         o2 = x2 * cos + x1 * sin
-        return torch.cat((o1, o2), dim=-1)
+        return torch.cat((o1, o2), dim=-1).to(dtype)
 
 
 class RotaryEmbedding(torch.nn.Module):
diff --git a/nemo_automodel/components/moe/parallelizer.py b/nemo_automodel/components/moe/parallelizer.py
index 4ff7ac085..1eec0e54a 100644
--- a/nemo_automodel/components/moe/parallelizer.py
+++ b/nemo_automodel/components/moe/parallelizer.py
@@ -34,6 +34,7 @@
     GroupedExpertsDeepEP,
     MoE,
 )
+from nemo_automodel.shared.utils import dtype_from_str
 
 logger = logging.getLogger(__name__)
 _CP_STREAM = None
@@ -129,7 +130,11 @@ def apply_fsdp(
     mp_policy: MixedPrecisionPolicy | None = None,
     offload_policy: OffloadPolicy | None = None,
     reshard_after_forward: bool = False,
+    lm_head_precision: str | torch.dtype | None = None,
 ):
+    if isinstance(lm_head_precision, str):
+        lm_head_precision = dtype_from_str(lm_head_precision, default=None)
+
     if mp_policy is None:
         mp_policy = MixedPrecisionPolicy(
             param_dtype=torch.bfloat16, reduce_dtype=torch.float32, output_dtype=torch.bfloat16
@@ -175,7 +180,22 @@
 
     lm_head = getattr(_model, "lm_head", None) or getattr(model, "lm_head", None)
     if lm_head is not None:
-        fully_shard_default(lm_head)
+        # Use custom mixed precision policy for lm_head if lm_head_precision is specified
+        if lm_head_precision == torch.float32:
+            lm_head_mp_policy = MixedPrecisionPolicy(
+                param_dtype=torch.float32,
+                reduce_dtype=torch.float32,
+                output_dtype=torch.float32,
+            )
+            fully_shard(
+                lm_head,
+                mesh=fsdp_mesh,
+                reshard_after_forward=reshard_after_forward,
+                mp_policy=lm_head_mp_policy,
+                offload_policy=offload_policy,
+            )
+        else:
+            fully_shard_default(lm_head)
 
     fully_shard_default(_model)
 
@@ -214,6 +234,7 @@ def parallelize_model(
     ep_shard_axis_names: tuple[str, ...] | None = None,
     activation_checkpointing: bool = False,
     reshard_after_forward: bool = False,
+    lm_head_precision: str | torch.dtype | None = None,
 ):
     assert tp_axis_name is None or world_mesh[tp_axis_name].size() == 1, (
         "Tensor parallelism not supported for custom MoE models"
     )
@@ -251,4 +272,5 @@
         ep_shard_enabled=ep_shard_mesh is not None and ep_shard_mesh.size() > 1,
         ep_shard_mesh=ep_shard_mesh,
         reshard_after_forward=reshard_after_forward,
+        lm_head_precision=lm_head_precision,
     )
diff --git a/tests/unit_tests/models/gpt_oss/test_rope_utils.py b/tests/unit_tests/models/gpt_oss/test_rope_utils.py
index 2becb7339..b21be6a85 100644
--- a/tests/unit_tests/models/gpt_oss/test_rope_utils.py
+++ b/tests/unit_tests/models/gpt_oss/test_rope_utils.py
@@ -152,7 +152,7 @@ def test_partial_rotary_preserves_passthrough(self):
             x[..., rotary_dim:],
             rtol=0,
             atol=0,
-            msg="Pass-through dimensions should be exactly preserved"
+            msg="Pass-through dimensions should be exactly preserved",
         )
 
     def test_partial_rotary_different_factors(self):
@@ -174,6 +174,135 @@ def test_partial_rotary_different_factors(self):
 
         # Verify pass-through is preserved
         torch.testing.assert_close(result[..., rotary_dim:], x_pass)
 
+    def test_float32_computation_with_fp16_input(self):
+        """Test that computation happens in float32 even with fp16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Create fp16 input
+        x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+        cos = torch.randn(seq_len, head_dim // 2)
+        sin = torch.randn(seq_len, head_dim // 2)
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_fp16, cos, sin)
+
+        # Output should be fp16
+        assert result.dtype == torch.float16
+
+        # Compare with fp32 computation for numerical accuracy
+        x_fp32 = x_fp16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The fp16 result should be close to the fp32 result when cast to fp32
+        torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-3, atol=1e-3)
+
+    def test_float32_computation_with_bfloat16_input(self):
+        """Test that computation happens in float32 even with bfloat16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Create bfloat16 input
+        x_bf16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
+        cos = torch.randn(seq_len, head_dim // 2)
+        sin = torch.randn(seq_len, head_dim // 2)
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_bf16, cos, sin)
+
+        # Output should be bfloat16
+        assert result.dtype == torch.bfloat16
+
+        # Compare with fp32 computation for numerical accuracy
+        x_fp32 = x_bf16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The bf16 result should be close to the fp32 result when cast to fp32
+        torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-2, atol=1e-2)
+
+    def test_cos_sin_dtype_independence(self):
+        """Test that cos/sin dtype doesn't affect output dtype"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+
+        # Test with different cos/sin dtypes
+        for cos_sin_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+            sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+
+            result = apply_rotary_emb(x, cos, sin)
+
+            # Output dtype should match input x dtype, not cos/sin dtype
+            assert result.dtype == x.dtype
+
+    def test_partial_rotary_float32_computation_with_fp16(self):
+        """Test that partial rotary also uses float32 computation with fp16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+        rotary_dim = 32  # Only rotate half the dimensions
+
+        # Create fp16 input
+        x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+        cos = torch.randn(seq_len, rotary_dim // 2)
+        sin = torch.randn(seq_len, rotary_dim // 2)
+
+        # Store the pass-through part
+        x_pass_original = x_fp16[..., rotary_dim:].clone()
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_fp16, cos, sin)
+
+        # Output should be fp16
+        assert result.dtype == torch.float16
+
+        # Pass-through dimensions should be exactly preserved (no dtype conversion artifacts)
+        torch.testing.assert_close(result[..., rotary_dim:], x_pass_original, rtol=0, atol=0)
+
+        # Compare with fp32 computation
+        x_fp32 = x_fp16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The rotated part should be close to fp32 computation
+        torch.testing.assert_close(
+            result[..., :rotary_dim].to(torch.float32), result_fp32[..., :rotary_dim], rtol=1e-3, atol=1e-3
+        )
+
+    def test_numerical_stability_with_mixed_dtypes(self):
+        """Test numerical stability when x, cos, sin have different dtypes"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Test various dtype combinations
+        dtype_combinations = [
+            (torch.float16, torch.float32),
+            (torch.bfloat16, torch.float32),
+            (torch.float32, torch.float16),
+        ]
+
+        for x_dtype, cos_sin_dtype in dtype_combinations:
+            x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=x_dtype)
+            cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+            sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+
+            # Should not raise any errors
+            result = apply_rotary_emb(x, cos, sin)
+
+            # Output should match input x dtype
+            assert result.dtype == x_dtype
+            assert result.shape == x.shape
+
 
 class TestRotaryEmbedding:
     """Tests for RotaryEmbedding class"""
@@ -536,11 +665,13 @@ def test_different_batch_patterns(self):
             dtype=torch.float32,
         )
 
-        position_ids = torch.tensor([
-            [0, 1, 2, 3],  # Sequential
-            [0, 0, 1, 1],  # Repeated
-            [10, 20, 30, 40],  # Large gaps
-        ])
+        position_ids = torch.tensor(
+            [
+                [0, 1, 2, 3],  # Sequential
+                [0, 0, 1, 1],  # Repeated
+                [10, 20, 30, 40],  # Large gaps
+            ]
+        )
 
         freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd")
 
@@ -601,10 +732,7 @@ def test_freqs_cis_consistency_across_ranks(self, cp_size, cp_rank):
             if len(indices) > 1:
                 # All tokens at this position should have identical freqs_cis
                 for i in range(1, len(indices)):
-                    torch.testing.assert_close(
-                        freqs_cis_rank[indices[0]],
-                        freqs_cis_rank[indices[i]]
-                    )
+                    torch.testing.assert_close(freqs_cis_rank[indices[0]], freqs_cis_rank[indices[i]])
 
     def test_freqs_cis_cp_with_variable_sequence_lengths(self):
         """Test freqs_cis with variable-length sequences and CP splitting"""
@@ -697,7 +825,7 @@ def test_full_rope_pipeline(self):
         # Step 2: Extract cos and sin from freqs_cis
         # freqs_cis contains concatenated cos and sin
         cos = freqs_cis[..., :32]  # First half is cos
-        sin = freqs_cis[..., 32:]   # Second half is sin
+        sin = freqs_cis[..., 32:]  # Second half is sin
 
         # Step 3: Apply rotary embeddings
         x = torch.randn(batch_size, seq_len, num_heads, 64)
diff --git a/tests/unit_tests/moe/test_parallelizer.py b/tests/unit_tests/moe/test_parallelizer.py
index 3be38a9c6..bda0f6898 100644
--- a/tests/unit_tests/moe/test_parallelizer.py
+++ b/tests/unit_tests/moe/test_parallelizer.py
@@ -184,6 +184,17 @@ def create_selective_checkpoint_contexts(policy_factory):
     # ops.aten.mm.default sentinel
     aten = types.SimpleNamespace(mm=types.SimpleNamespace(default=object()))
     torch_stub.ops = types.SimpleNamespace(aten=aten)
+
+    # dtype and device classes for type annotations
+    class dtype:
+        pass
+
+    class device:
+        pass
+
+    torch_stub.dtype = dtype
+    torch_stub.device = device
+
     # common dtypes referenced by code
     torch_stub.bfloat16 = object()
     torch_stub.float32 = object()
@@ -644,3 +655,201 @@ def __init__(self):
         ep_shard_axis_names=None,
         activation_checkpointing=False,
     )
+
+
+def test_apply_fsdp_with_lm_head_precision_fp32(monkeypatch):
+    """Test that apply_fsdp applies custom MixedPrecisionPolicy to lm_head when lm_head_precision is fp32."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    monkeypatch.setattr(P, "MoE", DummyMoE)
+
+    fully_shard_mock = MagicMock()
+    mp_policy_mock = MagicMock(return_value="MP_POLICY")
+    monkeypatch.setattr(P, "fully_shard", fully_shard_mock)
+    monkeypatch.setattr(P, "MixedPrecisionPolicy", mp_policy_mock)
+
+    torch_stub = sys.modules["torch"]
+    block = DummyBlock(mlp=DummyMoE())
+    lm = object()
+    model = DummyModel([block], lm_head=lm)
+    fsdp_mesh = object()
+
+    P.apply_fsdp(
+        model=model,
+        fsdp_mesh=fsdp_mesh,
+        pp_enabled=False,
+        ep_enabled=False,
+        ep_shard_enabled=False,
+        lm_head_precision=torch_stub.float32,
+    )
+
+    # Find the lm_head call
+    lm_call = _find_call_by_first_arg(fully_shard_mock, lm)
+    assert lm_call is not None
+    _, lm_kwargs = lm_call
+
+    # Verify custom MixedPrecisionPolicy was created with fp32 for all dtypes
+    assert mp_policy_mock.call_count >= 2  # default + lm_head
+    # Find the call for lm_head's custom policy
+    fp32_policy_calls = [
+        call for call in mp_policy_mock.call_args_list
+        if call[1].get("param_dtype") == torch_stub.float32
+        and call[1].get("reduce_dtype") == torch_stub.float32
+        and call[1].get("output_dtype") == torch_stub.float32
+    ]
+    assert len(fp32_policy_calls) == 1
+
+
+def test_apply_fsdp_without_lm_head_precision_uses_default_policy(monkeypatch):
+    """Test that apply_fsdp uses default MixedPrecisionPolicy for lm_head when lm_head_precision is None."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    monkeypatch.setattr(P, "MoE", DummyMoE)
+
+    fully_shard_mock = MagicMock()
+    mp_policy_mock = MagicMock(return_value="MP_POLICY")
+    monkeypatch.setattr(P, "fully_shard", fully_shard_mock)
+    monkeypatch.setattr(P, "MixedPrecisionPolicy", mp_policy_mock)
+
+    block = DummyBlock(mlp=DummyMoE())
+    lm = object()
+    model = DummyModel([block], lm_head=lm)
+    fsdp_mesh = object()
+
+    P.apply_fsdp(
+        model=model,
+        fsdp_mesh=fsdp_mesh,
+        pp_enabled=False,
+        ep_enabled=False,
+        ep_shard_enabled=False,
+        lm_head_precision=None,
+    )
+
+    # Find the lm_head call
+    lm_call = _find_call_by_first_arg(fully_shard_mock, lm)
+    assert lm_call is not None
+
+    # Should only have one MixedPrecisionPolicy call (the default one)
+    assert mp_policy_mock.call_count == 1
+
+
+def test_parallelize_model_passes_lm_head_precision_to_apply_fsdp(monkeypatch):
+    """Test that parallelize_model passes lm_head_precision to apply_fsdp."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    apply_fsdp_mock = MagicMock()
+    monkeypatch.setattr(P, "apply_fsdp", apply_fsdp_mock)
+    monkeypatch.setattr(P, "apply_ep", MagicMock())
+    monkeypatch.setattr(P, "apply_ac", MagicMock())
+
+    world_mesh = FakeWorldMesh({("dp",): 2}, mesh_dim_names=["dp"])
+    moe_mesh = None
+
+    torch_stub = sys.modules["torch"]
+
+    class Inner:
+        def __init__(self):
+            self.moe_config = type("MC", (), {"n_routed_experts": 4})()
+
+    class Outer:
+        def __init__(self):
+            self.model = Inner()
+
+    model = Outer()
+
+    P.parallelize_model(
+        model=model,
+        world_mesh=world_mesh,
+        moe_mesh=moe_mesh,
+        pp_enabled=False,
+        dp_axis_names=("dp",),
+        lm_head_precision=torch_stub.float32,
+    )
+
+    # Verify apply_fsdp was called with lm_head_precision
+    apply_fsdp_mock.assert_called_once()
+    _, kwargs = apply_fsdp_mock.call_args
+    assert kwargs.get("lm_head_precision") == torch_stub.float32
+
+
+def test_apply_fsdp_with_lm_head_precision_string_input(monkeypatch):
+    """Test that apply_fsdp accepts string input for lm_head_precision and converts to torch.dtype."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    monkeypatch.setattr(P, "MoE", DummyMoE)
+
+    fully_shard_mock = MagicMock()
+    mp_policy_mock = MagicMock(return_value="MP_POLICY")
+    monkeypatch.setattr(P, "fully_shard", fully_shard_mock)
+    monkeypatch.setattr(P, "MixedPrecisionPolicy", mp_policy_mock)
+
+    torch_stub = sys.modules["torch"]
+
+    # Mock dtype_from_str to convert string to torch.float32
+    def mock_dtype_from_str(val, default=None):
+        if val == "float32" or val == "torch.float32":
+            return torch_stub.float32
+        return default
+
+    monkeypatch.setattr(P, "dtype_from_str", mock_dtype_from_str)
+
+    block = DummyBlock(mlp=DummyMoE())
+    lm = object()
+    model = DummyModel([block], lm_head=lm)
+    fsdp_mesh = object()
+
+    P.apply_fsdp(
+        model=model,
+        fsdp_mesh=fsdp_mesh,
+        pp_enabled=False,
+        ep_enabled=False,
+        ep_shard_enabled=False,
+        lm_head_precision="float32",
+    )
+
+    # Find the lm_head call
+    lm_call = _find_call_by_first_arg(fully_shard_mock, lm)
+    assert lm_call is not None
+
+    # Verify custom MixedPrecisionPolicy was created with fp32 for all dtypes
+    assert mp_policy_mock.call_count >= 2  # default + lm_head
+    # Find the call for lm_head's custom policy
+    fp32_policy_calls = [
+        call for call in mp_policy_mock.call_args_list
+        if call[1].get("param_dtype") == torch_stub.float32
+        and call[1].get("reduce_dtype") == torch_stub.float32
+        and call[1].get("output_dtype") == torch_stub.float32
+    ]
+    assert len(fp32_policy_calls) == 1
+
+
+def test_parallelize_model_with_lm_head_precision_string_input(monkeypatch):
+    """Test that parallelize_model accepts string input for lm_head_precision."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    apply_fsdp_mock = MagicMock()
+    monkeypatch.setattr(P, "apply_fsdp", apply_fsdp_mock)
+    monkeypatch.setattr(P, "apply_ep", MagicMock())
+    monkeypatch.setattr(P, "apply_ac", MagicMock())
+
+    world_mesh = FakeWorldMesh({("dp",): 2}, mesh_dim_names=["dp"])
+    moe_mesh = None
+
+    class Inner:
+        def __init__(self):
+            self.moe_config = type("MC", (), {"n_routed_experts": 4})()
+
+    class Outer:
+        def __init__(self):
+            self.model = Inner()
+
+    model = Outer()
+
+    P.parallelize_model(
+        model=model,
+        world_mesh=world_mesh,
+        moe_mesh=moe_mesh,
+        pp_enabled=False,
+        dp_axis_names=("dp",),
+        lm_head_precision="float32",
+    )
+
+    # Verify apply_fsdp was called with lm_head_precision as a string
+    apply_fsdp_mock.assert_called_once()
+    _, kwargs = apply_fsdp_mock.call_args
+    assert kwargs.get("lm_head_precision") == "float32"