
Commit c08c9d4
[float8 moe training] validate float8 moe parallelism config (#1360)
## Summary

Validate that only FSDP and HSDP are used for float8 MoE training. TP support is in progress, and CP/PP are untested. 2D+ parallelism is untested as well.

## Test plan

- Command: `NGPU=4 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="experts" --parallelism.tensor_parallel_degree=2`
- Error: `AssertionError: Float8 MoE training prototype does not yet support tensor parallelism`
1 parent c31ff8b commit c08c9d4

1 file changed: +12 −0

torchtitan/components/quantization/float8.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -57,6 +57,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.moe_fqns = float8_config.moe_fqns_prototype
         self.filter_fn = self._init_filter_fn(float8_config)
 
+        # Validate MoE training prototype limitations.
+        if self.moe_fqns:
+            assert (
+                job_config.parallelism.tensor_parallel_degree == 1
+            ), "Float8 MoE training prototype does not yet support tensor parallelism"
+            assert (
+                job_config.parallelism.pipeline_parallel_degree == 1
+            ), "Float8 MoE training prototype does not yet support pipeline parallelism"
+            assert (
+                job_config.parallelism.context_parallel_degree == 1
+            ), "Float8 MoE training prototype does not yet support context parallelism"
+
         if float8_config.recipe_name is not None:
             assert not float8_config.enable_fsdp_float8_all_gather, (
                 "using `float8_config.enable_fsdp_float8_all_gather` together "
```
