
Commit 3579886

fix
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 5ffda53 commit 3579886

File tree

2 files changed: +92 −6 lines changed


nemo_automodel/components/moe/parallelizer.py

Lines changed: 6 additions & 6 deletions

@@ -34,7 +34,7 @@
     GroupedExpertsDeepEP,
     MoE,
 )
-from nemo_automodel.components.moe.utils import BackendConfig
+from nemo_automodel.shared.utils import dtype_from_str

 logger = logging.getLogger(__name__)
 _CP_STREAM = None
@@ -130,9 +130,11 @@ def apply_fsdp(
     mp_policy: MixedPrecisionPolicy | None = None,
     offload_policy: OffloadPolicy | None = None,
     reshard_after_forward: bool = False,
-    backend_config: BackendConfig | None = None,
-    lm_head_precision: torch.dtype | None = None,
+    lm_head_precision: str | torch.dtype | None = None,
 ):
+    if isinstance(lm_head_precision, str):
+        lm_head_precision = dtype_from_str(lm_head_precision, default=None)
+
     if mp_policy is None:
         mp_policy = MixedPrecisionPolicy(
             param_dtype=torch.bfloat16, reduce_dtype=torch.float32, output_dtype=torch.bfloat16
@@ -232,8 +234,7 @@ def parallelize_model(
     ep_shard_axis_names: tuple[str, ...] | None = None,
     activation_checkpointing: bool = False,
     reshard_after_forward: bool = False,
-    backend_config: BackendConfig | None = None,
-    lm_head_precision: torch.dtype | None = None,
+    lm_head_precision: str | torch.dtype | None = None,
 ):
     assert tp_axis_name is None or world_mesh[tp_axis_name].size() == 1, (
         "Tensor parallelism not supported for custom MoE models"
@@ -271,6 +272,5 @@ def parallelize_model(
         ep_shard_enabled=ep_shard_mesh is not None and ep_shard_mesh.size() > 1,
         ep_shard_mesh=ep_shard_mesh,
         reshard_after_forward=reshard_after_forward,
-        backend_config=backend_config,
         lm_head_precision=lm_head_precision,
     )
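With this change, lm_head_precision can be passed either as a torch.dtype or as a plain string (for example, straight from a YAML or CLI config), and apply_fsdp normalizes the string form through dtype_from_str. A minimal usage sketch, assuming a model and world_mesh constructed elsewhere (both are placeholders, not part of this commit):

    from nemo_automodel.components.moe.parallelizer import parallelize_model

    # The string "float32" is forwarded to apply_fsdp, which converts it to
    # torch.float32 via dtype_from_str before applying the lm_head precision policy.
    parallelize_model(
        model=model,            # placeholder: an already-built MoE model wrapper
        world_mesh=world_mesh,  # placeholder: a DeviceMesh with a "dp" axis
        moe_mesh=None,
        pp_enabled=False,
        dp_axis_names=("dp",),
        lm_head_precision="float32",  # equivalent to passing torch.float32
    )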

tests/unit_tests/moe/test_parallelizer.py

Lines changed: 86 additions & 0 deletions

@@ -767,3 +767,89 @@ def __init__(self):
     apply_fsdp_mock.assert_called_once()
     _, kwargs = apply_fsdp_mock.call_args
     assert kwargs.get("lm_head_precision") == torch_stub.float32
+
+
+def test_apply_fsdp_with_lm_head_precision_string_input(monkeypatch):
+    """Test that apply_fsdp accepts string input for lm_head_precision and converts to torch.dtype."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    monkeypatch.setattr(P, "MoE", DummyMoE)
+
+    fully_shard_mock = MagicMock()
+    mp_policy_mock = MagicMock(return_value="MP_POLICY")
+    monkeypatch.setattr(P, "fully_shard", fully_shard_mock)
+    monkeypatch.setattr(P, "MixedPrecisionPolicy", mp_policy_mock)
+
+    torch_stub = sys.modules["torch"]
+
+    # Mock dtype_from_str to convert string to torch.float32
+    def mock_dtype_from_str(val, default=None):
+        if val == "float32" or val == "torch.float32":
+            return torch_stub.float32
+        return default
+
+    monkeypatch.setattr(P, "dtype_from_str", mock_dtype_from_str)
+
+    block = DummyBlock(mlp=DummyMoE())
+    lm = object()
+    model = DummyModel([block], lm_head=lm)
+    fsdp_mesh = object()
+
+    P.apply_fsdp(
+        model=model,
+        fsdp_mesh=fsdp_mesh,
+        pp_enabled=False,
+        ep_enabled=False,
+        ep_shard_enabled=False,
+        lm_head_precision="float32",
+    )
+
+    # Find the lm_head call
+    lm_call = _find_call_by_first_arg(fully_shard_mock, lm)
+    assert lm_call is not None
+
+    # Verify custom MixedPrecisionPolicy was created with fp32 for all dtypes
+    assert mp_policy_mock.call_count >= 2  # default + lm_head
+    # Find the call for lm_head's custom policy
+    fp32_policy_calls = [
+        call for call in mp_policy_mock.call_args_list
+        if call[1].get("param_dtype") == torch_stub.float32
+        and call[1].get("reduce_dtype") == torch_stub.float32
+        and call[1].get("output_dtype") == torch_stub.float32
+    ]
+    assert len(fp32_policy_calls) == 1
+
+
+def test_parallelize_model_with_lm_head_precision_string_input(monkeypatch):
+    """Test that parallelize_model accepts string input for lm_head_precision."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    apply_fsdp_mock = MagicMock()
+    monkeypatch.setattr(P, "apply_fsdp", apply_fsdp_mock)
+    monkeypatch.setattr(P, "apply_ep", MagicMock())
+    monkeypatch.setattr(P, "apply_ac", MagicMock())
+
+    world_mesh = FakeWorldMesh({("dp",): 2}, mesh_dim_names=["dp"])
+    moe_mesh = None
+
+    class Inner:
+        def __init__(self):
+            self.moe_config = type("MC", (), {"n_routed_experts": 4})()
+
+    class Outer:
+        def __init__(self):
+            self.model = Inner()
+
+    model = Outer()
+
+    P.parallelize_model(
+        model=model,
+        world_mesh=world_mesh,
+        moe_mesh=moe_mesh,
+        pp_enabled=False,
+        dp_axis_names=("dp",),
+        lm_head_precision="float32",
+    )
+
+    # Verify apply_fsdp was called with lm_head_precision as a string
+    apply_fsdp_mock.assert_called_once()
+    _, kwargs = apply_fsdp_mock.call_args
+    assert kwargs.get("lm_head_precision") == "float32"
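The tests stub out dtype_from_str; the real helper lives in nemo_automodel.shared.utils and is not shown in this diff. A minimal sketch of the behavior the stub assumes (a dtype_from_str(val, default=None) signature that accepts both "float32" and "torch.float32" spellings); the actual implementation may differ:

    import torch

    def dtype_from_str(val, default=None):
        # Strip an optional "torch." prefix, then look the name up on the torch module.
        name = str(val).removeprefix("torch.")
        dtype = getattr(torch, name, None)
        # Fall back to the caller-supplied default for unknown or non-dtype names.
        return dtype if isinstance(dtype, torch.dtype) else default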

0 commit comments
