From f7cd64cdb481b08d6238e804af75f6c182c60791 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 18 Jun 2025 06:07:44 -0700
Subject: [PATCH] [not for land] testing out float8 128_1_128_128 blockwise
 scaling

Summary:

Test drive of https://github.com/pytorch/ao/pull/2386, not for land

Test Plan:

```bash
with-proxy CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.converters float8 --model.print_after_conversion
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/components/quantization/float8.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py
index 4a7e2a651..6043bbada 100644
--- a/torchtitan/components/quantization/float8.py
+++ b/torchtitan/components/quantization/float8.py
@@ -101,6 +101,19 @@ def convert(self, model: nn.Module):
         if not self.enabled:
             return
 
+        from torchao.quantization import quantize_
+        from torchao.prototype.deep_gemm_float8_training.linear import (
+            DeepGemmFloat8LinearConfig,
+        )
+
+        quantize_(
+            model,
+            config=DeepGemmFloat8LinearConfig(),
+            filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "output",
+        )
+        logger.info("enabled DeepGemm dense training")
+        return
+
         from torchao.float8 import convert_to_float8_training
 
         # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
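
For context, below is a minimal standalone sketch of the same conversion path the patch hooks into torchtitan, using the imports and `quantize_` call shown in the diff. It assumes a torchao build that includes the prototype `DeepGemmFloat8LinearConfig` from pytorch/ao#2386; the `ToyModel` module, its dimensions, and the device/dtype choices are illustrative only, not part of the patch.

```python
# Sketch, not for land: convert every nn.Linear except the output projection
# to the prototype DeepGemm float8 (128_1_128_128 blockwise scaled) training
# linear, mirroring the filter_fn used in the patch above.
# Assumes torchao with https://github.com/pytorch/ao/pull/2386 merged/applied;
# DeepGemm kernels likely also require a Hopper-class (SM90) GPU at runtime.
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.deep_gemm_float8_training.linear import (
    DeepGemmFloat8LinearConfig,
)


class ToyModel(nn.Module):
    # Illustrative two-layer model; "output" is the fqn excluded by filter_fn.
    def __init__(self, dim: int = 256):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.output = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output(self.w1(x))


model = ToyModel().cuda().bfloat16()

# Mutates the model in place, swapping eligible nn.Linear modules for the
# DeepGemm float8 training linears; the final "output" layer is left as-is.
quantize_(
    model,
    config=DeepGemmFloat8LinearConfig(),
    filter_fn=lambda mod, fqn: isinstance(mod, nn.Linear) and fqn != "output",
)
print(model)  # inspect which modules were converted
```

The `filter_fn` shape matches the patch: conversion is gated on module type plus fully-qualified name, which is how the output projection is kept in high precision while the rest of the dense layers use blockwise-scaled float8.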