diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py
index 4a7e2a651..6043bbada 100644
--- a/torchtitan/components/quantization/float8.py
+++ b/torchtitan/components/quantization/float8.py
@@ -101,6 +101,19 @@ def convert(self, model: nn.Module):
         if not self.enabled:
             return
 
+        from torchao.quantization import quantize_
+        from torchao.prototype.deep_gemm_float8_training.linear import (
+            DeepGemmFloat8LinearConfig,
+        )
+
+        quantize_(
+            model,
+            config=DeepGemmFloat8LinearConfig(),
+            filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "output",
+        )
+        logger.info("enabled DeepGemm dense training")
+        return
+
         from torchao.float8 import convert_to_float8_training
 
         # Mutates the model inplace replacing instances of nn.Linear with Float8Linear