From a840ef535dbf91bdfca97742d469a541eb33c1e1 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 16 Jun 2025 13:57:11 +0000 Subject: [PATCH 01/11] quantize_affine_float8/dequantize_affine_float8 not decomposed on inductor --- test/float8/test_compile.py | 59 ++++++++++++++++++++++++ torchao/quantization/quant_primitives.py | 24 +++++++++- torchao/utils.py | 11 +++-- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ac5d1f8d96..fb4d6ea316 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -37,6 +37,10 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig +from torchao.quantization.quant_primitives import ( + dequantize_affine_float8, + quantize_affine_float8, +) from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -392,5 +396,60 @@ def test_dynamic_scale_numeric_parity( assert torch.equal(float8_eager._data, float8_compile._data) +@pytest.mark.parametrize( + "float8_dtype", + [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], +) +@pytest.mark.parametrize( + "hp_dtype", + [ + torch.float32, + torch.float16, + torch.bfloat16, + ], +) +@unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "skipping when torch version is 2.5 or lower" +) +def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): + input = torch.randn(10, 10) + with torch.no_grad(): + torch._dynamo.reset() + expected_scale = torch.tensor(2.0) + expected_quantized = quantize_affine_float8( + input, + expected_scale, + float8_dtype, + ) + expected_dequantized = dequantize_affine_float8( + expected_quantized, + expected_scale, + output_dtype=hp_dtype, + ) + test_q, (code_q,) = torch._inductor.utils.run_and_get_code( + torch.compile(quantize_affine_float8), + input, + expected_scale, + float8_dtype, + ) + torch.testing.FileCheck().check( + "torch.ops.torchao.quantize_affine_float8.default" + ).run(code_q) + test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code( + torch.compile(dequantize_affine_float8), + test_q, + expected_scale, + hp_dtype, + ) + torch.testing.FileCheck().check( + "torch.ops.torchao.dequantize_affine_float8.default" + ).run(code_dq) + torch.testing.assert_close(expected_quantized, test_q) + torch.testing.assert_close(expected_dequantized, test_dq) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index df136bc06e..72b3935157 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -2270,10 +2270,11 @@ def _expand_scale_to_tensor_shape( return expanded_scale +@_register_custom_op(quant_lib, False) def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype: torch.dtype = torch.float8_e4m3fn, + float8_dtype: torch.dtype, ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. 
@@ -2290,10 +2291,20 @@ def _quantize_affine_float8( return fp8_tensor +@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta") +def _quantize_affine_float8_meta( + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty_like(tensor, dtype=float8_dtype) + + +@_register_custom_op(quant_lib, False) def _dequantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - output_dtype: torch.dtype = torch.float32, + output_dtype: torch.dtype, ) -> torch.Tensor: """ Dequantizes the float8 tensor to high precision tensor. @@ -2305,3 +2316,12 @@ def _dequantize_affine_float8( hp_tensor = fp8_tensor * scale_expanded return hp_tensor.to(output_dtype) + + +@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta") +def _dequantize_affine_float8_meta( + tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty_like(tensor, dtype=output_dtype) diff --git a/torchao/utils.py b/torchao/utils.py index 416d23d785..99a0a729f5 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int: return n + k - (n % k) -def _register_custom_op(lib): +def _register_custom_op(lib, implicit=True): """This decorator is used to preserve some high level operators for torch.export.export while still allow them to be decomposed for inductor path @@ -206,6 +206,10 @@ def _the_op_that_needs_to_be_preserved(...) """ from torch._inductor.decomposition import register_decomposition + dispatch_key = ( + "CompositeImplicitAutograd" if implicit else "CompositeExplicitAutograd" + ) + def decorator(fn): if TORCH_VERSION_AT_LEAST_2_5: from torch._library.infer_schema import infer_schema @@ -221,11 +225,12 @@ def decorator(fn): op_name = fn.__name__[1:] schema = op_name + infer_schema(fn, mutates_args={}) lib.define(schema) - lib.impl(op_name, fn, "CompositeImplicitAutograd") + lib.impl(op_name, fn, dispatch_key) lib_namespace = lib.ns op = getattr(getattr(torch.ops, lib_namespace), op_name) - register_decomposition([op])(fn) + if implicit: + register_decomposition([op])(fn) return op else: return fn From 02d045b267c20cca276770682d8769ec9d8c258b Mon Sep 17 00:00:00 2001 From: wengshiy Date: Mon, 16 Jun 2025 14:31:21 +0000 Subject: [PATCH 02/11] remove redundant unittest.skipIf --- test/float8/test_compile.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index fb4d6ea316..43f0d8e2f2 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -411,9 +411,6 @@ def test_dynamic_scale_numeric_parity( torch.bfloat16, ], ) -@unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "skipping when torch version is 2.5 or lower" -) def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): input = torch.randn(10, 10) with torch.no_grad(): From 9860c56e87ef83986f9ca25b87c32e7a3023f186 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 15:44:02 +0000 Subject: [PATCH 03/11] fix rebase issue --- test/float8/test_compile.py | 10 ++++------ torchao/quantization/quant_primitives.py | 8 ++++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 43f0d8e2f2..64feaf7b5d 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -37,10 +37,6 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig -from 
torchao.quantization.quant_primitives import ( - dequantize_affine_float8, - quantize_affine_float8, -) from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -412,6 +408,8 @@ def test_dynamic_scale_numeric_parity( ], ) def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): + quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8 + dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8 input = torch.randn(10, 10) with torch.no_grad(): torch._dynamo.reset() @@ -419,7 +417,7 @@ def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): expected_quantized = quantize_affine_float8( input, expected_scale, - float8_dtype, + float8_dtype=float8_dtype, ) expected_dequantized = dequantize_affine_float8( expected_quantized, @@ -430,7 +428,7 @@ def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): torch.compile(quantize_affine_float8), input, expected_scale, - float8_dtype, + float8_dtype=float8_dtype, ) torch.testing.FileCheck().check( "torch.ops.torchao.quantize_affine_float8.default" diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 72b3935157..56e8422197 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -2274,7 +2274,7 @@ def _expand_scale_to_tensor_shape( def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype: torch.dtype, + float8_dtype: torch.dtype = torch.float8_e4m3fn, ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. @@ -2295,7 +2295,7 @@ def _quantize_affine_float8( def _quantize_affine_float8_meta( tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype: torch.dtype, + float8_dtype: torch.dtype = torch.float8_e4m3fn, ) -> torch.Tensor: return torch.empty_like(tensor, dtype=float8_dtype) @@ -2304,7 +2304,7 @@ def _quantize_affine_float8_meta( def _dequantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, - output_dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ Dequantizes the float8 tensor to high precision tensor. @@ -2322,6 +2322,6 @@ def _dequantize_affine_float8( def _dequantize_affine_float8_meta( tensor: torch.Tensor, scale: torch.Tensor, - output_dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: return torch.empty_like(tensor, dtype=output_dtype) From ca662f343e9164afa761e352e2a8e72400501c0a Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 15:47:32 +0000 Subject: [PATCH 04/11] change dispatch key to a flag decomposed --- torchao/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/utils.py b/torchao/utils.py index 99a0a729f5..4814c7ec63 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int: return n + k - (n % k) -def _register_custom_op(lib, implicit=True): +def _register_custom_op(lib, decomposed=True): """This decorator is used to preserve some high level operators for torch.export.export while still allow them to be decomposed for inductor path @@ -207,7 +207,7 @@ def _the_op_that_needs_to_be_preserved(...) 
from torch._inductor.decomposition import register_decomposition dispatch_key = ( - "CompositeImplicitAutograd" if implicit else "CompositeExplicitAutograd" + "CompositeImplicitAutograd" if decomposed else "CompositeExplicitAutograd" ) def decorator(fn): @@ -229,7 +229,7 @@ def decorator(fn): lib_namespace = lib.ns op = getattr(getattr(torch.ops, lib_namespace), op_name) - if implicit: + if decomposed: register_decomposition([op])(fn) return op else: From f51a5bea091e7c6a7c288aa17d8a2484e3f68a65 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 14:59:33 +0000 Subject: [PATCH 05/11] support scaled_mm on inductor --- .../pt2e/test_x86inductor_fusion.py | 68 +++++ .../quantization/pt2e/inductor_passes/x86.py | 237 ++++++++++++++++++ 2 files changed, 305 insertions(+) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index ffaa4573d8..579a9583e5 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2427,6 +2427,74 @@ def matcher_check_fn(): self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) + @skipIfNoONEDNN + # @parametrize("has_bias", [True, False]) + # @parametrize("dtype", [torch.float32, torch.bfloat16]) + # @parametrize("input_dim_exceeds_two", [True, False]) + @parametrize("has_bias", [True, ]) + @parametrize("dtype", [torch.float32, ]) + @parametrize("input_dim_exceeds_two", [False]) + def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two): + class FP8QDQLinear(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.qtype = torch.float8_e4m3fn + self.weight = torch.randn((out_features, in_features)).to(self.qtype) + self.weight_scale = 2.0 + self.scale = 2.0 + self.bias = None + if has_bias: + self.bias = torch.randn((out_features,)).to(dtype) + + def forward(self, input): + weight = torch.ops.torchao.dequantize_affine_float8( + tensor=self.weight.data, + scale=torch.tensor(self.weight_scale), + output_dtype=torch.float + ) + if dtype != torch.float: + weight = weight.to(dtype) + + q_input = torch.ops.torchao.quantize_affine_float8( + tensor=input, + scale=torch.tensor(self.scale), + float8_dtype=self.qtype, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8( + tensor=q_input, + scale=torch.tensor(self.scale), + output_dtype=torch.float + ) + if dtype != torch.float: + dq_input = dq_input.to(dtype) + + out = torch.nn.functional.linear(dq_input, weight, self.bias) + return out + + class Mod(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.l0 = FP8QDQLinear(in_features, out_features) + + def forward(self, x): + y = self.l0(x) + return y + + M1, M2, N, K = 2, 3, 13, 16 + M = M1 * M2 + mod = Mod(N, K) + if input_dim_exceeds_two: + v = torch.randn(M1, M2, N) + else: + v = torch.randn(M, N) + v = v.to(dtype) + + def matcher_check_fn(): + self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], 1) + + self._test_common(mod, (v,), matcher_check_fn) + + @dynamo_config.patch( { "dynamic_shapes": True, diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 4ccb2a1f31..a937e38a63 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -2740,6 +2740,241 @@ def _register_qlinear_binary_fusion(): ) +def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): + # + - - - - | - - - 
- - - | - - - - + + # | dq_per_tensor dq_per_tensor | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(quant_per_tensor) | + # | | | + # | OPT(reshape) | + assert dtype in [torch.float32, torch.bfloat16] + dequant_wgt_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_dtype"), + ) + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + dequant_wgt_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + dequantize_per_tensor_activation_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_dq_dtype"), + ) + + dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_fp8_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern + + +def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two): + def _inner(match): + # input_contiguous = True + # # Check dequant pattern has only 1 user. 
+ # ( + # linear_node, + # _, + # ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + # input_index = 1 if linear_node.target is aten.addmm.default else 0 + # assert dtype in [torch.float32, torch.bfloat16] + # ( + # dequant_node, + # _, + # _, + # _, + # ) = _get_linear_dq_node( + # linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + # ) + # assert dequant_node.target is quantized_decomposed.dequantize_per_tensor.tensor + + # # only support float8_e4m3 input + # if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: + # return False + + # if len(list(dequant_node.users)) != 1: + # # Ensure the dequant pattern only has 1 user + # # since we will delete the dequant pattern here + # return False + + return True + + return _inner + + +def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two), + pass_number=0, + ) + def scaled_mm_fusion(match: Match, *args, **kwargs): + input_contiguous = True + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + dequant_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_tensor = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_tensor = weight_to_bf16_node.args[0] + assert ( + dequant_per_tensor.target + is torch.ops.torchao.dequantize_affine_float8.default + ) + + # Activation QParams + qx, x_scale = ( + kwargs["x"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale = ( + kwargs["q_weight"], + kwargs["w_scale"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + scaled_mm_input_node = qx + if input_dim_exceeds_two: + new_reshape_args: tuple[Any, ...] = (qx, act_reshape_node.args[1]) + new_act_reshape_node = graph.call_function( + torch.ops.aten.reshape.default, args=new_reshape_args + ) + scaled_mm_input_node = new_act_reshape_node + # Insert weight prepack node and the qlinear node + permute_weight_inputs = ( + qw, + t_node.args[1], + ) + permute_weight_op = torch.ops.aten.permute.default + permute_weight_node = graph.call_function( + permute_weight_op, args=permute_weight_inputs + ) + output_scale = torch.tensor(1.0) + new_args: tuple[Any, ...] 
= ( + scaled_mm_input_node, + permute_weight_node, + x_scale, + w_scale, + bias, + output_scale, # output_scale + dtype, # output_dtype + False, # use_fast_accum + ) + new_linear_node = graph.call_function( + torch.ops.aten._scaled_mm.default, args=new_args + ) + + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + graph.erase_node(linear_node) + if input_dim_exceeds_two: + graph.erase_node(act_reshape_node) + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(dequant_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_tensor) + + counters["inductor"]["scaled_mm_matcher_count"] += 1 + counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes) + + +def _register_scaled_mm(): + fp8_linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [False, True] + ) + for dtype, input_dim_exceeds_two in fp8_linear_weight_prepack_cases: + patterns = _generate_dequant_fp8_linear_node_pattern( + dtype, input_dim_exceeds_two + ) + for pattern in patterns: + _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two) + + @functools.lru_cache(None) def _register_quantization_weight_pack_pass(): # Step 1: Dequant promotion for int8-mixed-fp32/bf16 @@ -2763,6 +2998,8 @@ def _register_quantization_weight_pack_pass(): _register_qlinear_unary_fusion() _register_qlinear_binary_fusion() + _register_scaled_mm() + def quant_lift_up(module_graph: torch.fx.graph.Graph): """ From 719793c64420edbfb5c0a5eb7d3a2a0221a7c1f2 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 18 Jun 2025 16:48:47 +0000 Subject: [PATCH 06/11] fix rebase issue --- .../pt2e/test_x86inductor_fusion.py | 14 ++--- .../quantization/pt2e/inductor_passes/x86.py | 58 +++++++++---------- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 579a9583e5..43868bfed0 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2426,14 +2426,10 @@ def matcher_check_fn(): if test_for_pointwise_binary: self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) - @skipIfNoONEDNN - # @parametrize("has_bias", [True, False]) - # @parametrize("dtype", [torch.float32, torch.bfloat16]) - # @parametrize("input_dim_exceeds_two", [True, False]) - @parametrize("has_bias", [True, ]) - @parametrize("dtype", [torch.float32, ]) - @parametrize("input_dim_exceeds_two", [False]) + @parametrize("has_bias", [True, False]) + @parametrize("dtype", [torch.float32, torch.bfloat16]) + @parametrize("input_dim_exceeds_two", [True, False]) def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two): class FP8QDQLinear(torch.nn.Module): def __init__(self, in_features, out_features): @@ -2450,7 +2446,7 @@ def forward(self, input): weight = torch.ops.torchao.dequantize_affine_float8( tensor=self.weight.data, scale=torch.tensor(self.weight_scale), - output_dtype=torch.float + output_dtype=torch.float, ) if dtype != torch.float: weight = weight.to(dtype) @@ -2463,7 +2459,7 @@ def forward(self, input): dq_input = torch.ops.torchao.dequantize_affine_float8( tensor=q_input, scale=torch.tensor(self.scale), - output_dtype=torch.float + output_dtype=torch.float, ) if dtype != 
torch.float: dq_input = dq_input.to(dtype) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index a937e38a63..1dfb2ce3d2 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -2758,7 +2758,7 @@ def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): torch.ops.torchao.dequantize_affine_float8.default, KeywordArg("q_weight"), KeywordArg("w_scale"), - KeywordArg("w_dtype"), + output_dtype=KeywordArg("w_dtype"), ) t_pattern = CallFunction( aten.permute.default, @@ -2773,7 +2773,7 @@ def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): torch.ops.torchao.dequantize_affine_float8.default, KeywordArg("x"), KeywordArg("x_scale"), - KeywordArg("x_dq_dtype"), + output_dtype=KeywordArg("x_dq_dtype"), ) dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape( @@ -2816,33 +2816,33 @@ def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two): def _inner(match): - # input_contiguous = True - # # Check dequant pattern has only 1 user. - # ( - # linear_node, - # _, - # ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) - - # input_index = 1 if linear_node.target is aten.addmm.default else 0 - # assert dtype in [torch.float32, torch.bfloat16] - # ( - # dequant_node, - # _, - # _, - # _, - # ) = _get_linear_dq_node( - # linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous - # ) - # assert dequant_node.target is quantized_decomposed.dequantize_per_tensor.tensor - - # # only support float8_e4m3 input - # if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: - # return False - - # if len(list(dequant_node.users)) != 1: - # # Ensure the dequant pattern only has 1 user - # # since we will delete the dequant pattern here - # return False + input_contiguous = True + # Check dequant pattern has only 1 user. 
+ ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + ( + dequant_node, + _, + _, + _, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + assert dequant_node.target is torch.ops.torchao.dequantize_affine_float8.default + + # only support float8_e4m3 input + if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: + return False + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False return True From 48a3d999eba50277b735a00848f62a008f9f2376 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 10:06:21 +0000 Subject: [PATCH 07/11] support dequant promtion for fp8 --- .../quantization/pt2e/inductor_passes/x86.py | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 1dfb2ce3d2..e5cfd95f71 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -116,18 +116,26 @@ def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): return unary_fusion(computation_call) -def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): - dequantize_per_tensor_activation_pattern = CallFunction( - quantized_decomposed.dequantize_per_tensor.tensor - if is_tensor_overload - else quantized_decomposed.dequantize_per_tensor.default, - KeywordArg("x"), - KeywordArg("x_scale"), - KeywordArg("x_zp"), - KeywordArg("x_quant_min"), - KeywordArg("x_quant_max"), - KeywordArg("x_dq_dtype"), - ) +def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False, is_fp8=False): + if is_fp8: + dequantize_per_tensor_activation_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8.default, + KeywordArg("x"), + KeywordArg("x_scale"), + output_dtype=KeywordArg("x_dq_dtype"), + ) + else: + dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.tensor + if is_tensor_overload + else quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("x_quant_min"), + KeywordArg("x_quant_max"), + KeywordArg("x_dq_dtype"), + ) return dequantize_per_tensor_activation_pattern @@ -491,6 +499,7 @@ def _inner(match): if dequant_pattern_end_node.target not in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, prims.convert_element_type.default, aten.reshape.default, ]: @@ -520,6 +529,7 @@ def _inner(match): in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, ] and len(list(dequant_pattern_end_node.users)) > 1 ): @@ -530,7 +540,7 @@ def _inner(match): return _inner -def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_promotion_pattern(dtype), @@ -586,6 +596,7 @@ def clone_to_new_node(graph, source_node, user_node): assert 
dequant_pattern_end_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, prims.convert_element_type.default, aten.reshape.default, ] @@ -598,6 +609,7 @@ def _find_first_node_in_dequant_pattern(_node): if _node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, ]: # For a dequant pattern, we expect the start node is a dequantize_per_tensor node return _node @@ -614,6 +626,7 @@ def _find_first_node_in_dequant_pattern(_node): assert dequant_pattern_start_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8.default, ] # Clone the dequant pattern for each user node @@ -1332,9 +1345,9 @@ def _generate_linear_dynamic_fp16_pattern( def _register_dequant_promotion(): dequant_pattern_cases = itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) - for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: + for dtype, input_dim_exceeds_two, is_tensor_overload, is_fp8 in dequant_pattern_cases: # 4 dequantization patterns will be matched based on the dtype and input dimension size. # Case 1: int8-mixed-fp32, input dim size is 2 # Case 2: int8-mixed-fp32, input dim size exceeds 2 @@ -1355,11 +1368,15 @@ def _register_dequant_promotion(): # OPT(to_fp32) OPT(to_fp32) # + - - | - - - | - - + # quant quant + if is_fp8 and not is_tensor_overload: + continue + _register_dequant_promotion_pass( _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern( - is_tensor_overload=is_tensor_overload + is_tensor_overload=is_tensor_overload, + is_fp8=is_fp8, ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, @@ -1369,6 +1386,7 @@ def _register_dequant_promotion(): ), pass_number=0, dtype=dtype, + is_fp8=is_fp8, ) # pass_number=0 to run before weight prepack @@ -2853,7 +2871,7 @@ def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two), - pass_number=0, + pass_number=1, ) def scaled_mm_fusion(match: Match, *args, **kwargs): input_contiguous = True From 1921b2f8d7a3340caf6345a2a89d56968abc7f73 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 14:19:34 +0000 Subject: [PATCH 08/11] add ut --- .../quantization/pt2e/test_x86inductor_fusion.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 43868bfed0..fa981dc4d6 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -2430,7 +2430,8 @@ def matcher_check_fn(): @parametrize("has_bias", [True, False]) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("input_dim_exceeds_two", [True, False]) - def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two): + @parametrize("check_reuse_input", [True, False]) + def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input): class FP8QDQLinear(torch.nn.Module): def __init__(self, in_features, out_features): 
super().__init__() @@ -2468,17 +2469,23 @@ def forward(self, input): return out class Mod(torch.nn.Module): - def __init__(self, in_features, out_features): + def __init__(self, in_features, out_features, check_reuse_input): super().__init__() self.l0 = FP8QDQLinear(in_features, out_features) + self.check_reuse_input = check_reuse_input + if self.check_reuse_input: + self.l1 = FP8QDQLinear(in_features, out_features) def forward(self, x): y = self.l0(x) + if self.check_reuse_input: + z = self.l1(x) + y += z return y M1, M2, N, K = 2, 3, 13, 16 M = M1 * M2 - mod = Mod(N, K) + mod = Mod(N, K, check_reuse_input) if input_dim_exceeds_two: v = torch.randn(M1, M2, N) else: @@ -2486,7 +2493,8 @@ def forward(self, x): v = v.to(dtype) def matcher_check_fn(): - self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], 1) + counter = 2 if check_reuse_input else 1 + self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], counter) self._test_common(mod, (v,), matcher_check_fn) From 0335415f8449bfa59613a1832184fa74999cf153 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 14:35:38 +0000 Subject: [PATCH 09/11] remove redundant codes --- torchao/quantization/pt2e/inductor_passes/x86.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index e5cfd95f71..0af7d21a99 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -540,7 +540,7 @@ def _inner(match): return _inner -def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False): +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_promotion_pattern(dtype), @@ -1386,7 +1386,6 @@ def _register_dequant_promotion(): ), pass_number=0, dtype=dtype, - is_fp8=is_fp8, ) # pass_number=0 to run before weight prepack From a5bb4d0cece16b2beaf36323e7b5440be9d9a198 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 16:24:35 +0000 Subject: [PATCH 10/11] fix lint --- torchao/quantization/pt2e/inductor_passes/x86.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 0af7d21a99..dd9f0e6c21 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -116,7 +116,9 @@ def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): return unary_fusion(computation_call) -def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False, is_fp8=False): +def get_dequantize_per_tensor_activation_pattern( + is_tensor_overload=False, is_fp8=False +): if is_fp8: dequantize_per_tensor_activation_pattern = CallFunction( torch.ops.torchao.dequantize_affine_float8.default, @@ -1347,7 +1349,12 @@ def _register_dequant_promotion(): dequant_pattern_cases = itertools.product( [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) - for dtype, input_dim_exceeds_two, is_tensor_overload, is_fp8 in dequant_pattern_cases: + for ( + dtype, + input_dim_exceeds_two, + is_tensor_overload, + is_fp8, + ) in dequant_pattern_cases: # 4 dequantization patterns will be matched based on the dtype and input dimension size. 
# Case 1: int8-mixed-fp32, input dim size is 2 # Case 2: int8-mixed-fp32, input dim size exceeds 2 From 0c7f8eacaf98d181d3cb60c775af9d503dae736b Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 25 Jun 2025 16:30:37 +0000 Subject: [PATCH 11/11] resolve conflict --- test/float8/test_compile.py | 54 ------------------------------------- 1 file changed, 54 deletions(-) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c9280cd70c..aaf9d3d3f5 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -392,59 +392,5 @@ def test_dynamic_scale_numeric_parity( assert torch.equal(float8_eager._data, float8_compile._data) -@pytest.mark.parametrize( - "float8_dtype", - [ - torch.float8_e4m3fn, - torch.float8_e5m2, - ], -) -@pytest.mark.parametrize( - "hp_dtype", - [ - torch.float32, - torch.float16, - torch.bfloat16, - ], -) -def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype): - quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8 - dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8 - input = torch.randn(10, 10) - with torch.no_grad(): - torch._dynamo.reset() - expected_scale = torch.tensor(2.0) - expected_quantized = quantize_affine_float8( - input, - expected_scale, - float8_dtype=float8_dtype, - ) - expected_dequantized = dequantize_affine_float8( - expected_quantized, - expected_scale, - output_dtype=hp_dtype, - ) - test_q, (code_q,) = torch._inductor.utils.run_and_get_code( - torch.compile(quantize_affine_float8), - input, - expected_scale, - float8_dtype=float8_dtype, - ) - torch.testing.FileCheck().check( - "torch.ops.torchao.quantize_affine_float8.default" - ).run(code_q) - test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code( - torch.compile(dequantize_affine_float8), - test_q, - expected_scale, - hp_dtype, - ) - torch.testing.FileCheck().check( - "torch.ops.torchao.dequantize_affine_float8.default" - ).run(code_dq) - torch.testing.assert_close(expected_quantized, test_q) - torch.testing.assert_close(expected_dequantized, test_dq) - - if __name__ == "__main__": pytest.main([__file__])
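
For reference, the behavior this series enables can be exercised end to end with a short standalone script. This is a minimal sketch rather than part of the patches: it assumes torch >= 2.5 and a torchao build with these patches applied (so that importing torchao.quantization registers the torch.ops.torchao.* custom ops), and it mirrors the op names, keyword arguments, and FileCheck patterns used by the test added in patch 01 and reworked in patch 03.

import torch
from torch._inductor.utils import run_and_get_code

# Assumption: importing torchao.quantization is what registers the custom ops below.
import torchao.quantization  # noqa: F401

quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8

x = torch.randn(10, 10)
scale = torch.tensor(2.0)

with torch.no_grad():
    # Compile the ops directly; with this series applied, the generated Inductor
    # code should still call the custom ops rather than their decompositions.
    q, (code_q,) = run_and_get_code(
        torch.compile(quantize_affine_float8),
        x,
        scale,
        float8_dtype=torch.float8_e4m3fn,
    )
    assert "torch.ops.torchao.quantize_affine_float8.default" in code_q

    dq, (code_dq,) = run_and_get_code(
        torch.compile(dequantize_affine_float8),
        q,
        scale,
        output_dtype=torch.float32,
    )
    assert "torch.ops.torchao.dequantize_affine_float8.default" in code_dq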