@@ -767,3 +767,89 @@ def __init__(self):
     apply_fsdp_mock.assert_called_once()
     _, kwargs = apply_fsdp_mock.call_args
     assert kwargs.get("lm_head_precision") == torch_stub.float32
+
+
+def test_apply_fsdp_with_lm_head_precision_string_input(monkeypatch):
+    """Test that apply_fsdp accepts string input for lm_head_precision and converts to torch.dtype."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    monkeypatch.setattr(P, "MoE", DummyMoE)
+
+    fully_shard_mock = MagicMock()
+    mp_policy_mock = MagicMock(return_value="MP_POLICY")
+    monkeypatch.setattr(P, "fully_shard", fully_shard_mock)
+    monkeypatch.setattr(P, "MixedPrecisionPolicy", mp_policy_mock)
+
+    torch_stub = sys.modules["torch"]
+
+    # Mock dtype_from_str to convert string to torch.float32
+    def mock_dtype_from_str(val, default=None):
+        if val == "float32" or val == "torch.float32":
+            return torch_stub.float32
+        return default
+
+    monkeypatch.setattr(P, "dtype_from_str", mock_dtype_from_str)
+
+    block = DummyBlock(mlp=DummyMoE())
+    lm = object()
+    model = DummyModel([block], lm_head=lm)
+    fsdp_mesh = object()
+
+    P.apply_fsdp(
+        model=model,
+        fsdp_mesh=fsdp_mesh,
+        pp_enabled=False,
+        ep_enabled=False,
+        ep_shard_enabled=False,
+        lm_head_precision="float32",
+    )
+
+    # Find the lm_head call
+    lm_call = _find_call_by_first_arg(fully_shard_mock, lm)
+    assert lm_call is not None
+
+    # Verify a custom MixedPrecisionPolicy was created with fp32 for all dtypes
+    assert mp_policy_mock.call_count >= 2  # default + lm_head
+    # Find the call for lm_head's custom policy
+    fp32_policy_calls = [
+        call for call in mp_policy_mock.call_args_list
+        if call[1].get("param_dtype") == torch_stub.float32
+        and call[1].get("reduce_dtype") == torch_stub.float32
+        and call[1].get("output_dtype") == torch_stub.float32
+    ]
+    assert len(fp32_policy_calls) == 1
+
+
+def test_parallelize_model_with_lm_head_precision_string_input(monkeypatch):
+    """Test that parallelize_model accepts string input for lm_head_precision."""
+    P = _import_parallelizer_with_stubs(monkeypatch)
+    apply_fsdp_mock = MagicMock()
+    monkeypatch.setattr(P, "apply_fsdp", apply_fsdp_mock)
+    monkeypatch.setattr(P, "apply_ep", MagicMock())
+    monkeypatch.setattr(P, "apply_ac", MagicMock())
+
+    world_mesh = FakeWorldMesh({("dp",): 2}, mesh_dim_names=["dp"])
+    moe_mesh = None
+
+    class Inner:
+        def __init__(self):
+            self.moe_config = type("MC", (), {"n_routed_experts": 4})()
+
+    class Outer:
+        def __init__(self):
+            self.model = Inner()
+
+    model = Outer()
+
+    P.parallelize_model(
+        model=model,
+        world_mesh=world_mesh,
+        moe_mesh=moe_mesh,
+        pp_enabled=False,
+        dp_axis_names=("dp",),
+        lm_head_precision="float32",
+    )
+
+    # Verify apply_fsdp was called with lm_head_precision as a string
+    apply_fsdp_mock.assert_called_once()
+    _, kwargs = apply_fsdp_mock.call_args
+    assert kwargs.get("lm_head_precision") == "float32"
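
The two tests above stub out `dtype_from_str`, `MixedPrecisionPolicy`, and `fully_shard`, so they only pin down the expected call pattern: a string `lm_head_precision` is normalized to a `torch.dtype`, and the `lm_head` gets its own all-fp32 policy. For context, a minimal sketch of that behavior follows; the helper names (`dtype_from_str`, `shard_lm_head`) and the FSDP2 import path are assumptions for illustration, not necessarily the repository's actual `apply_fsdp` code:

```python
# Hypothetical sketch, assuming roughly how a string lm_head_precision could be
# handled. FSDP2 entry points are imported from torch.distributed.fsdp here;
# older PyTorch versions expose them under torch.distributed._composable.fsdp.
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def dtype_from_str(val, default=None):
    """Accept 'float32' or 'torch.float32' (or an actual dtype) and return a torch.dtype."""
    if isinstance(val, torch.dtype):
        return val
    name = str(val).split(".")[-1]  # strip an optional "torch." prefix
    dtype = getattr(torch, name, default)
    return dtype if isinstance(dtype, torch.dtype) else default


def shard_lm_head(lm_head, fsdp_mesh, lm_head_precision=None):
    """Wrap lm_head with fully_shard, overriding the mixed-precision policy if requested."""
    dtype = dtype_from_str(lm_head_precision) if lm_head_precision is not None else None
    if dtype is not None:
        # Custom policy: keep params, reductions, and outputs all in the requested dtype.
        policy = MixedPrecisionPolicy(
            param_dtype=dtype, reduce_dtype=dtype, output_dtype=dtype
        )
        fully_shard(lm_head, mesh=fsdp_mesh, mp_policy=policy)
    else:
        # No override: fall back to the default mixed-precision behavior.
        fully_shard(lm_head, mesh=fsdp_mesh)
```

Under that reading, the first test's assertion that exactly one `MixedPrecisionPolicy` call uses fp32 for `param_dtype`, `reduce_dtype`, and `output_dtype` corresponds to the override branch, while the second test only checks that `parallelize_model` forwards the raw string to `apply_fsdp`.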