
Commit bc7f40a

danielvegamyhre authored and wwwjn committed
[float8] add float8 rowwise MoE prototype (#1245)
# Summary
- Adds a `--float8.moe_fqns_prototype="..."` option to the float8 training API.
- The API accepts a comma-separated list of FQNs to which the MoE float8 training conversion is applied.
- `quantize_` with the `MoETrainingConfig` recursively swaps nn.Parameter data tensors to a tensor subclass, which overrides grouped_mm with the [dynamic quant + scaled grouped mm](https://github.com/pytorch/ao/blob/d963a8840e3c228e303fe14aff5d9be7017c92b6/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py#L20) prototype. Context: see the implementation of GroupedExperts [here](https://github.com/pytorch/torchtitan/blob/ca10545e41582fed4ebb00db4c13db71194a0dfa/torchtitan/experiments/llama4/model/moe.py#L85-L87).

# Testing
- Tested manually with the torchao `convert_moe_to_float8_training` prototype ([PR](pytorch/ao#2275)) and confirmed single-GPU training works as expected.

# Limitations
- Only supports single-GPU training so far.
- Only performs the grouped_mm override for routed experts (see the condition [here](https://github.com/pytorch/ao/pull/2275/files#diff-c529b94621368096076db6bec8a6fc058d7f7595c39cd59965c657ed5dea861cR29-R33)). For shared experts, the torchao prototype will need to be updated to support a 3D A tensor (see torchtitan [here](https://github.com/pytorch/torchtitan/blob/ca10545e41582fed4ebb00db4c13db71194a0dfa/torchtitan/experiments/llama4/model/moe.py#L316)).
1 parent f3811a9 commit bc7f40a
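For context, a rough standalone sketch of the conversion flow described above, using only the torchao entry points this commit imports (`quantize_`, `MoETrainingConfig`). It assumes a torchao nightly build that ships the `moe_training` prototype; `ToyExperts` and `ToyMoE` are hypothetical stand-ins for torchtitan's GroupedExperts, not code from this PR.

```python
import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize_
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig


class ToyExperts(nn.Module):
    # Hypothetical stand-in for torchtitan's GroupedExperts: routed-expert
    # weights stored as 3D nn.Parameters, one slice per expert.
    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, dim))


class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = ToyExperts(num_experts=8, dim=256, hidden_dim=512)


model = ToyMoE()

# Select modules whose FQN contains one of the target substrings,
# mirroring the filter function added in float8.py in this commit.
target_fqns = ["experts"]

quantize_(
    model,
    config=MoETrainingConfig(),
    filter_fn=lambda mod, fqn: any(t in fqn for t in target_fqns),
)
# quantize_ swaps the matching nn.Parameter data tensors to the tensor
# subclass whose grouped_mm override performs dynamic float8 rowwise
# quantization + the scaled grouped GEMM prototype.
```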

File tree

3 files changed: +33, -0 lines


torchtitan/components/quantization/float8.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -53,6 +53,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         self.enabled = True
         self.filter_fqns = float8_config.filter_fqns
+        self.moe_fqns = float8_config.moe_fqns_prototype
 
         if float8_config.recipe_name is not None:
             assert (
@@ -114,6 +115,30 @@ def convert(self, model: nn.Module):
                 f"{self.config.enable_fsdp_float8_all_gather}"
             )
 
+        # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
+        # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
+        if self.moe_fqns:
+            from torchao.quantization.quant_api import quantize_
+
+            try:
+                from torchao.prototype.moe_training.conversion_utils import (
+                    MoETrainingConfig,
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "torchao installation does not have MoE training support. Please install torchao nightly build."
+                ) from e
+
+            def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+                for target_fqn in self.moe_fqns:
+                    if target_fqn in cur_fqn:
+                        return True
+                return False
+
+            config = MoETrainingConfig()
+            quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+            logger.info("Converted MoE to float8")
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         if not self.enabled:
             return
```
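As a small standalone illustration of the substring-based FQN matching that `moe_module_filter_fn` performs above (the example FQNs below are hypothetical, not taken from this PR):

```python
# Hypothetical module FQNs, loosely shaped like a Llama4-style MoE block;
# at conversion time the real names come from the model being converted.
example_fqns = [
    "layers.0.moe.router.gate",
    "layers.0.moe.experts",
    "layers.0.moe.shared_expert",
    "layers.1.moe.experts",
]

moe_fqns = ["experts"]  # e.g. from --float8.moe_fqns_prototype="experts"


def matches(cur_fqn: str) -> bool:
    # Same substring check as moe_module_filter_fn in convert().
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)


for fqn in example_fqns:
    print(f"{fqn}: {'convert' if matches(fqn) else 'skip'}")
# Only layers.0.moe.experts and layers.1.moe.experts match;
# "shared_expert" does not contain the substring "experts", so it is skipped.
```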

torchtitan/config_manager.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -513,6 +513,13 @@ class Float8:
     Not compatible with torch.compile.
     """
 
+    moe_fqns_prototype: list[str] | str = field(default_factory=list)
+    """
+    Comma-separated list of fully qualified names of MoE modules to apply float8 rowwise training to.
+    This is a prototype feature that requires the torchao nightly build.
+    Example: --float8.moe_fqns_prototype="experts"
+    """
+
 
 @dataclass
 class MX:
```
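The field type `list[str] | str` suggests the value may arrive either as a TOML list or as a comma-separated CLI string. A hedged sketch of the kind of normalization a consumer of this field might apply (`normalize_fqns` is a hypothetical helper; torchtitan's actual argument parsing is not part of this diff):

```python
def normalize_fqns(value: list[str] | str) -> list[str]:
    # Hypothetical helper: accept either a TOML list (["experts"]) or a
    # comma-separated CLI string ("experts,shared_experts").
    if isinstance(value, str):
        return [fqn.strip() for fqn in value.split(",") if fqn.strip()]
    return list(value)


assert normalize_fqns("experts,shared_experts") == ["experts", "shared_experts"]
assert normalize_fqns(["experts"]) == ["experts"]
```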

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output", "router.gate"]
+moe_fqns = ["experts"]
```
