@@ -102,21 +102,12 @@ def convert(self, model: nn.Module):
         if not self.enabled:
             return
 
-        from torchao.float8 import convert_to_float8_training
-
-        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
-        convert_to_float8_training(
-            model,
-            config=self.config,
-            module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
-        )
-        logger.info(
-            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
-            f"{self.config.enable_fsdp_float8_all_gather}"
-        )
-
         # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
         # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
+        # MoE conversion must take place before Float8Linear conversion, otherwise the Float8Linears will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
         if self.moe_fqns:
             from torchao.quantization.quant_api import quantize_
 
@@ -137,7 +128,23 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
             config = MoETrainingConfig()
             quantize_(model, config=config, filter_fn=moe_module_filter_fn)
-            logger.info("Converted MoE to float8")
+            logger.info(
+                f"Converted MoE layers matching FQNs {self.moe_fqns} "
+                "to use dynamic float8 rowwise quantization with scaled grouped GEMMs"
+            )
+
+        from torchao.float8 import convert_to_float8_training
+
+        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
+        convert_to_float8_training(
+            model,
+            config=self.config,
+            module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
+        )
+        logger.info(
+            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+            f"{self.config.enable_fsdp_float8_all_gather}"
+        )
 
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         if not self.enabled:
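
The added comments pin down an ordering constraint: torchao's `quantize_` swaps `Float8Linear` modules back to plain `nn.Linear` (see the linked `quant_api.py` lines), so the MoE conversion has to run before the `Float8Linear` swap. Below is a minimal standalone sketch of that ordering, assuming only the torchao APIs already used in this diff (`quantize_`, `convert_to_float8_training`, `Float8LinearConfig`); the function name `convert_moe_then_linear` and the filter callables are illustrative, and the `MoETrainingConfig` import is left to the caller since its path may vary across torchao versions. It is not torchtitan's actual API.

```python
# Minimal sketch of the ordering constraint, not torchtitan's actual converter.
# Assumptions: `convert_moe_then_linear` and the filter callables are illustrative;
# `moe_config` is expected to be a torchao MoETrainingConfig instance (import omitted).
from typing import Callable

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.quantization.quant_api import quantize_

# Filter predicates take (module, fqn) and return whether to convert that module.
FilterFn = Callable[[nn.Module, str], bool]


def convert_moe_then_linear(
    model: nn.Module,
    moe_config,
    float8_config: Float8LinearConfig,
    moe_filter_fn: FilterFn,
    linear_filter_fn: FilterFn,
) -> nn.Module:
    # 1) MoE conversion first: quantize_ mutates the model in place and, per the
    #    torchao code linked in the diff, would convert existing Float8Linear
    #    modules back to plain nn.Linear if it ran second.
    quantize_(model, config=moe_config, filter_fn=moe_filter_fn)

    # 2) Only now replace eligible nn.Linear modules with Float8Linear.
    convert_to_float8_training(
        model,
        config=float8_config,
        module_filter_fn=linear_filter_fn,
    )
    return model
```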