
Gemma3 27B crashes on 16 nodes w/ FSDP #425

@terrykong

Description


Gemma3 27B no longer works on 16 nodes (full run w/ FSDP):

https://github.com/NVIDIA-NeMo/RL/blob/ae89e126bedd57020dcb3b5173e393bac287d634/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml

I believe it's due to NVIDIA-NeMo/RL@9f7825e, which added the input/output embeddings to the parallelize plan, whereas before they were replicated.
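
For context, a rough sketch of what that change does to the TP plan. This is not the actual fsdp2_strategy_parallelize code; the module names (model.embed_tokens, lm_head) and the layouts are assumptions for illustration:

```python
# Rough sketch only -- not the actual Automodel parallelizer code.
# Module names (model.embed_tokens, lm_head) and layouts are assumptions.
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_plan = {
    # ... attention / MLP entries elided ...
    # Before NVIDIA-NeMo/RL@9f7825e the embeddings had no entry here, so their
    # weights stayed replicated across the 8 TP ranks and FSDP only had to cut
    # dim 0 across the 16 DP ranks (262208 / 16 = 16388, which is even).
    # With entries like these, the weights get TP-sharded on the vocab dim
    # (dim 0), and FSDP then requires dim 0 to divide by the combined shard count:
    "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),  # weight sharded on dim 0
    "lm_head": ColwiseParallel(),                                      # weight sharded on dim 0
}
# parallelize_module(model, tp_mesh, tp_plan)
```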

repro: uv run tests/test_suites/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.sh

```
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)   File "/code_snapshots/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long/3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py", line 589, in fsdp2_strategy_parallelize [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)     model = fully_shard( [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)             ^^^^^^^^^^^^ [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)   File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/distributed/_composable/contract.py", line 150, in wrapper [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)     updated = func(inp_module, *args, **kwargs) [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)   File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py", line 217, in fully_shard [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)     state._fsdp_param_group = FSDPParamGroup( [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)                               ^^^^^^^^^^^^^^^ [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)     FSDPParam( [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)     self._init_sharded_param(param, device, shard_placement_fn) [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)   File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)     return func(*args, **kwargs) [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)            ^^^^^^^^^^^^^^^^^^^^^ [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)   File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 335, in _init_sharded_param [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79)     raise NotImplementedError( [repeated 24x across cluster]
(DTensorPolicyWorkerV2 pid=1276099, ip=10.65.6.79) NotImplementedError: FSDP+TP sharding does not support uneven sharding for now: tensor dim 0 has size 262208 which cannot be evenly sharded into 128 shards. [repeated 24x across cluster]
```
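
The failing check is just divisibility of the vocab dimension by the total number of dim-0 shards; the 128 in the message matches the 16 nodes × 8 GPUs of this recipe:

```python
# Divisibility check behind the NotImplementedError above.
vocab_size = 262208   # Gemma3 embedding rows (tensor dim 0 in the error)
num_shards = 128      # 16 nodes x 8 GPUs/node

print(divmod(vocab_size, num_shards))  # (2048, 64): 64 leftover rows -> uneven
print(vocab_size % num_shards == 0)    # False
```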

https://wandb.ai/nvidia/nemo-rl/runs/z55enf3w

Would HSDP help?
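
A sketch of how HSDP would be wired with FSDP2 (a 2D DP mesh passed to fully_shard, where dim 0 replicates and dim 1 shards). The 2 × 8 split below is just an example, not what the recipe builds today:

```python
# Sketch of HSDP with FSDP2, assuming tp=8 and a 2 (replicate) x 8 (shard) DP split.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # public path in torch >= 2.6

mesh = init_device_mesh(
    "cuda",
    (2, 8, 8),  # 16 nodes x 8 GPUs = 128 ranks total
    mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
dp_mesh = mesh["dp_replicate", "dp_shard"]  # 2D mesh -> fully_shard runs HSDP

# If the evenness check is over tp * dp_shard dim-0 shards, this gives
# 8 * 8 = 64 and 262208 % 64 == 0, so the error above would not trigger.
# That's my reading of the traceback, not something I've verified end to end.
# fully_shard(model, mesh=dp_mesh)
```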
