diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 782889716..c7466a34c 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -55,6 +55,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.filter_fqns = float8_config.filter_fqns self.moe_fqns = float8_config.moe_fqns_prototype + # Validate MoE training prototype limitations. + if self.moe_fqns: + assert ( + job_config.parallelism.tensor_parallel_degree == 1 + ), "Float8 MoE training prototype does not yet support tensor parallelism" + assert ( + job_config.parallelism.pipeline_parallel_degree == 1 + ), "Float8 MoE training prototype does not yet support pipeline parallelism" + assert ( + job_config.parallelism.context_parallel_degree == 1 + ), "Float8 MoE training prototype does not yet support context parallelism" + if float8_config.recipe_name is not None: assert ( not float8_config.enable_fsdp_float8_all_gather