
Commit 42b0beb

[float8 moe training] Do MoE conversion before Float8Linear conversion to avoid unconverting Float8Linears (#1359)
## Summary

MoE conversion must take place before Float8Linear conversion, otherwise the Float8Linears will be converted back to nn.Linear (see [here](https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299)).

## Next steps

- Add a warning in torchao when this happens, or find a better way to avoid this.
1 parent 8004fef commit 42b0beb
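To make the required ordering concrete, here is a minimal, hypothetical sketch (not the torchtitan code itself) built on the same torchao APIs the diff below uses (`quantize_`, `MoETrainingConfig`, `convert_to_float8_training`); the helper name, the filter functions, and the `MoETrainingConfig` import path are illustrative assumptions.

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.quantization.quant_api import quantize_

# Assumed import path; the diff only shows MoETrainingConfig being instantiated.
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig


def convert_for_float8_training(model: nn.Module, moe_fqns: list[str]) -> None:
    """Hypothetical helper illustrating the conversion order this commit enforces."""

    def moe_filter(mod: nn.Module, cur_fqn: str) -> bool:
        # Target only modules whose FQN matches one of the MoE FQNs.
        return any(fqn in cur_fqn for fqn in moe_fqns)

    def dense_filter(mod: nn.Module, cur_fqn: str) -> bool:
        # Convert the remaining nn.Linear layers, skipping the MoE experts.
        return isinstance(mod, nn.Linear) and not any(fqn in cur_fqn for fqn in moe_fqns)

    # Step 1: MoE conversion first. Replaces expert nn.Parameters with
    # ScaledGroupedMMTensor for float8 rowwise quantization + scaled grouped GEMMs.
    quantize_(model, config=MoETrainingConfig(), filter_fn=moe_filter)

    # Step 2: Float8Linear conversion second. If this step ran first, the later
    # quantize_ call could swap the Float8Linear modules back to plain nn.Linear.
    convert_to_float8_training(
        model,
        config=Float8LinearConfig(),
        module_filter_fn=dense_filter,
    )
```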

File tree

1 file changed: +21 −14 lines changed

torchtitan/components/quantization/float8.py

Lines changed: 21 additions & 14 deletions
@@ -102,21 +102,12 @@ def convert(self, model: nn.Module):
         if not self.enabled:
             return
 
-        from torchao.float8 import convert_to_float8_training
-
-        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
-        convert_to_float8_training(
-            model,
-            config=self.config,
-            module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
-        )
-        logger.info(
-            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
-            f"{self.config.enable_fsdp_float8_all_gather}"
-        )
-
         # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
         # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
+        # MoE conversion must take place before Float8Linear conversion, otherwise the Float8Linears will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
         if self.moe_fqns:
             from torchao.quantization.quant_api import quantize_
 
@@ -137,7 +128,23 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
             config = MoETrainingConfig()
             quantize_(model, config=config, filter_fn=moe_module_filter_fn)
-            logger.info("Converted MoE to float8")
+            logger.info(
+                f"Converted MoE layers matching FQNS {self.moe_fqns} "
+                "to use dynamic float8 rowwise quantization with scaled grouped GEMMs"
+            )
+
+        from torchao.float8 import convert_to_float8_training
+
+        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
+        convert_to_float8_training(
+            model,
+            config=self.config,
+            module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
+        )
+        logger.info(
+            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+            f"{self.config.enable_fsdp_float8_all_gather}"
+        )
 
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         if not self.enabled:
