diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index ffaa4573d8..fa981dc4d6 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -2426,6 +2426,78 @@ def matcher_check_fn():
             if test_for_pointwise_binary:
                 self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)
 
+    @skipIfNoONEDNN
+    @parametrize("has_bias", [True, False])
+    @parametrize("dtype", [torch.float32, torch.bfloat16])
+    @parametrize("input_dim_exceeds_two", [True, False])
+    @parametrize("check_reuse_input", [True, False])
+    def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input):
+        class FP8QDQLinear(torch.nn.Module):
+            def __init__(self, in_features, out_features):
+                super().__init__()
+                self.qtype = torch.float8_e4m3fn
+                self.weight = torch.randn((out_features, in_features)).to(self.qtype)
+                self.weight_scale = 2.0
+                self.scale = 2.0
+                self.bias = None
+                if has_bias:
+                    self.bias = torch.randn((out_features,)).to(dtype)
+
+            def forward(self, input):
+                weight = torch.ops.torchao.dequantize_affine_float8(
+                    tensor=self.weight.data,
+                    scale=torch.tensor(self.weight_scale),
+                    output_dtype=torch.float,
+                )
+                if dtype != torch.float:
+                    weight = weight.to(dtype)
+
+                q_input = torch.ops.torchao.quantize_affine_float8(
+                    tensor=input,
+                    scale=torch.tensor(self.scale),
+                    float8_dtype=self.qtype,
+                )
+                dq_input = torch.ops.torchao.dequantize_affine_float8(
+                    tensor=q_input,
+                    scale=torch.tensor(self.scale),
+                    output_dtype=torch.float,
+                )
+                if dtype != torch.float:
+                    dq_input = dq_input.to(dtype)
+
+                out = torch.nn.functional.linear(dq_input, weight, self.bias)
+                return out
+
+        class Mod(torch.nn.Module):
+            def __init__(self, in_features, out_features, check_reuse_input):
+                super().__init__()
+                self.l0 = FP8QDQLinear(in_features, out_features)
+                self.check_reuse_input = check_reuse_input
+                if self.check_reuse_input:
+                    self.l1 = FP8QDQLinear(in_features, out_features)
+
+            def forward(self, x):
+                y = self.l0(x)
+                if self.check_reuse_input:
+                    z = self.l1(x)
+                    y += z
+                return y
+
+        M1, M2, N, K = 2, 3, 13, 16
+        M = M1 * M2
+        mod = Mod(N, K, check_reuse_input)
+        if input_dim_exceeds_two:
+            v = torch.randn(M1, M2, N)
+        else:
+            v = torch.randn(M, N)
+        v = v.to(dtype)
+
+        def matcher_check_fn():
+            counter = 2 if check_reuse_input else 1
+            self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], counter)
+
+        self._test_common(mod, (v,), matcher_check_fn)
+
     @dynamo_config.patch(
         {
diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index 4ccb2a1f31..dd9f0e6c21 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -116,18 +116,28 @@ def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
     return unary_fusion(computation_call)
 
 
-def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
-    dequantize_per_tensor_activation_pattern = CallFunction(
-        quantized_decomposed.dequantize_per_tensor.tensor
-        if is_tensor_overload
-        else quantized_decomposed.dequantize_per_tensor.default,
-        KeywordArg("x"),
-        KeywordArg("x_scale"),
-        KeywordArg("x_zp"),
-        KeywordArg("x_quant_min"),
-        KeywordArg("x_quant_max"),
-        KeywordArg("x_dq_dtype"),
-    )
+def get_dequantize_per_tensor_activation_pattern(
+    is_tensor_overload=False, is_fp8=False
+):
+    if is_fp8:
+        dequantize_per_tensor_activation_pattern = CallFunction(
+            torch.ops.torchao.dequantize_affine_float8.default,
+            KeywordArg("x"),
+            KeywordArg("x_scale"),
+            output_dtype=KeywordArg("x_dq_dtype"),
+        )
+    else:
+        dequantize_per_tensor_activation_pattern = CallFunction(
+            quantized_decomposed.dequantize_per_tensor.tensor
+            if is_tensor_overload
+            else quantized_decomposed.dequantize_per_tensor.default,
+            KeywordArg("x"),
+            KeywordArg("x_scale"),
+            KeywordArg("x_zp"),
+            KeywordArg("x_quant_min"),
+            KeywordArg("x_quant_max"),
+            KeywordArg("x_dq_dtype"),
+        )
     return dequantize_per_tensor_activation_pattern
 
 
@@ -491,6 +501,7 @@ def _inner(match):
         if dequant_pattern_end_node.target not in [
             quantized_decomposed.dequantize_per_tensor.default,
             quantized_decomposed.dequantize_per_tensor.tensor,
+            torch.ops.torchao.dequantize_affine_float8.default,
             prims.convert_element_type.default,
             aten.reshape.default,
         ]:
@@ -520,6 +531,7 @@ def _inner(match):
             in [
                 quantized_decomposed.dequantize_per_tensor.default,
                 quantized_decomposed.dequantize_per_tensor.tensor,
+                torch.ops.torchao.dequantize_affine_float8.default,
             ]
             and len(list(dequant_pattern_end_node.users)) > 1
         ):
@@ -586,6 +598,7 @@ def clone_to_new_node(graph, source_node, user_node):
         assert dequant_pattern_end_node.target in [
             quantized_decomposed.dequantize_per_tensor.default,
             quantized_decomposed.dequantize_per_tensor.tensor,
+            torch.ops.torchao.dequantize_affine_float8.default,
             prims.convert_element_type.default,
             aten.reshape.default,
         ]
@@ -598,6 +611,7 @@ def _find_first_node_in_dequant_pattern(_node):
             if _node.target in [
                 quantized_decomposed.dequantize_per_tensor.default,
                 quantized_decomposed.dequantize_per_tensor.tensor,
+                torch.ops.torchao.dequantize_affine_float8.default,
             ]:
                 # For a dequant pattern, we expect the start node is a dequantize_per_tensor node
                 return _node
@@ -614,6 +628,7 @@ def _find_first_node_in_dequant_pattern(_node):
         assert dequant_pattern_start_node.target in [
             quantized_decomposed.dequantize_per_tensor.default,
             quantized_decomposed.dequantize_per_tensor.tensor,
+            torch.ops.torchao.dequantize_affine_float8.default,
         ]
 
         # Clone the dequant pattern for each user node
@@ -1332,9 +1347,14 @@ def _generate_linear_dynamic_fp16_pattern(
 
 def _register_dequant_promotion():
     dequant_pattern_cases = itertools.product(
-        [torch.float32, torch.bfloat16], [True, False], [True, False]
+        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False]
     )
-    for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
+    for (
+        dtype,
+        input_dim_exceeds_two,
+        is_tensor_overload,
+        is_fp8,
+    ) in dequant_pattern_cases:
         # 4 dequantization patterns will be matched based on the dtype and input dimension size.
         # Case 1: int8-mixed-fp32, input dim size is 2
         # Case 2: int8-mixed-fp32, input dim size exceeds 2
@@ -1355,11 +1375,15 @@ def _register_dequant_promotion():
         #  OPT(to_fp32)   OPT(to_fp32)
         #   + - - | - - - | - - +
         #      quant       quant
+        if is_fp8 and not is_tensor_overload:
+            continue
+
         _register_dequant_promotion_pass(
             _may_generate_pattern_with_reshape(
                 _may_generate_pattern_with_dtype_convert(
                     get_dequantize_per_tensor_activation_pattern(
-                        is_tensor_overload=is_tensor_overload
+                        is_tensor_overload=is_tensor_overload,
+                        is_fp8=is_fp8,
                     ),
                     KeywordArg("autocast_act_dtype"),
                     dtype == torch.bfloat16,
@@ -2740,6 +2764,241 @@ def _register_qlinear_binary_fusion():
         )
 
+
+def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two):
+    # + - - - - | - - - - - - | - - - - +
+    # |  dq_per_tensor    dq_per_tensor |
+    # |        |                |       |
+    # |   OPT(to_bf16)    OPT(to_bf16)  |
+    # |        |                |       |
+    # |   OPT(reshape)       permute    |
+    # |          \             /        |
+    # |             addmm/mm            |
+    # |                |                |
+    # |      OPT(quant_per_tensor)      |
+    # |                |                |
+    # |          OPT(reshape)           |
+    assert dtype in [torch.float32, torch.bfloat16]
+    dequant_wgt_pattern = CallFunction(
+        torch.ops.torchao.dequantize_affine_float8.default,
+        KeywordArg("q_weight"),
+        KeywordArg("w_scale"),
+        output_dtype=KeywordArg("w_dtype"),
+    )
+    t_pattern = CallFunction(
+        aten.permute.default,
+        _may_generate_pattern_with_dtype_convert(
+            dequant_wgt_pattern,
+            KeywordArg("autocast_wgt_dtype"),
+            dtype == torch.bfloat16,
+        ),
+        KeywordArg("permute_axes"),
+    )
+    dequantize_per_tensor_activation_pattern = CallFunction(
+        torch.ops.torchao.dequantize_affine_float8.default,
+        KeywordArg("x"),
+        KeywordArg("x_scale"),
+        output_dtype=KeywordArg("x_dq_dtype"),
+    )
+
+    dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape(
+        CallFunction(
+            aten.addmm.default,
+            KeywordArg("b"),
+            _may_generate_pattern_with_reshape(
+                _may_generate_pattern_with_dtype_convert(
+                    dequantize_per_tensor_activation_pattern,
+                    KeywordArg("autocast_act_dtype"),
+                    dtype == torch.bfloat16,
+                ),
+                KeywordArg("act_reshape_size"),
+                input_dim_exceeds_two,
+            ),
+            t_pattern,
+        ),
+        KeywordArg("output_reshape_size"),
+        input_dim_exceeds_two,
+    )
+    dequant_fp8_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
+        CallFunction(
+            aten.mm.default,
+            _may_generate_pattern_with_reshape(
+                _may_generate_pattern_with_dtype_convert(
+                    dequantize_per_tensor_activation_pattern,
+                    KeywordArg("autocast_act_dtype"),
+                    dtype == torch.bfloat16,
+                ),
+                KeywordArg("act_reshape_size"),
+                input_dim_exceeds_two,
+            ),
+            t_pattern,
+        ),
+        KeywordArg("output_reshape_size"),
+        input_dim_exceeds_two,
+    )
+    return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern
+
+
+def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two):
+    def _inner(match):
+        input_contiguous = True
+        # Check dequant pattern has only 1 user.
+        (
+            linear_node,
+            _,
+        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
+
+        input_index = 1 if linear_node.target is aten.addmm.default else 0
+        assert dtype in [torch.float32, torch.bfloat16]
+        (
+            dequant_node,
+            _,
+            _,
+            _,
+        ) = _get_linear_dq_node(
+            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
+        )
+        assert dequant_node.target is torch.ops.torchao.dequantize_affine_float8.default
+
+        # only support float8_e4m3 input
+        if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn:
+            return False
+
+        if len(list(dequant_node.users)) != 1:
+            # Ensure the dequant pattern only has 1 user
+            # since we will delete the dequant pattern here
+            return False
+
+        return True
+
+    return _inner
+
+
+def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two):
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two),
+        pass_number=1,
+    )
+    def scaled_mm_fusion(match: Match, *args, **kwargs):
+        input_contiguous = True
+        assert dtype in [torch.float32, torch.bfloat16]
+        (
+            linear_node,
+            output_reshape_node,
+        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
+        input_index = 1 if linear_node.target is aten.addmm.default else 0
+        weight_index = input_index + 1
+
+        (
+            dequant_node,
+            act_reshape_node,
+            activation_to_bf16_node,
+            act_expand_node,
+        ) = _get_linear_dq_node(
+            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
+        )
+
+        if input_dim_exceeds_two and not input_contiguous:
+            wgt_expand_node = linear_node.args[weight_index]
+            assert wgt_expand_node.target is aten.expand.default
+            t_node = wgt_expand_node.args[0]
+        else:
+            t_node = linear_node.args[weight_index]
+
+        if dtype == torch.float32:
+            dequant_per_tensor = t_node.args[0]
+        else:
+            weight_to_bf16_node = t_node.args[0]
+            dequant_per_tensor = weight_to_bf16_node.args[0]
+        assert (
+            dequant_per_tensor.target
+            is torch.ops.torchao.dequantize_affine_float8.default
+        )
+
+        # Activation QParams
+        qx, x_scale = (
+            kwargs["x"],
+            kwargs["x_scale"],
+        )
+
+        # Weight QParams
+        qw, w_scale = (
+            kwargs["q_weight"],
+            kwargs["w_scale"],
+        )
+
+        # Params
+        bias = kwargs["b"] if "b" in kwargs else None
+
+        x_shape = qx.meta.get("tensor_meta").shape
+        if has_free_symbols(x_shape):
+            # For the dynamic shape case, we can't get the activation shape ahead of runtime.
+            x_shape = None
+        graph = match.graph
+        with graph.inserting_before(linear_node):
+            scaled_mm_input_node = qx
+            if input_dim_exceeds_two:
+                new_reshape_args: tuple[Any, ...] = (qx, act_reshape_node.args[1])
+                new_act_reshape_node = graph.call_function(
+                    torch.ops.aten.reshape.default, args=new_reshape_args
+                )
+                scaled_mm_input_node = new_act_reshape_node
+            # Insert the weight permute node and the scaled_mm node
+            permute_weight_inputs = (
+                qw,
+                t_node.args[1],
+            )
+            permute_weight_op = torch.ops.aten.permute.default
+            permute_weight_node = graph.call_function(
+                permute_weight_op, args=permute_weight_inputs
+            )
+            output_scale = torch.tensor(1.0)
+            new_args: tuple[Any, ...] = (
+                scaled_mm_input_node,
+                permute_weight_node,
+                x_scale,
+                w_scale,
+                bias,
+                output_scale,  # output_scale
+                dtype,  # output_dtype
+                False,  # use_fast_accum
+            )
+            new_linear_node = graph.call_function(
+                torch.ops.aten._scaled_mm.default, args=new_args
+            )
+
+            linear_node.replace_all_uses_with(new_linear_node)
+            new_linear_node.meta.update(linear_node.meta)
+
+            graph.erase_node(linear_node)
+            if input_dim_exceeds_two:
+                graph.erase_node(act_reshape_node)
+            if dtype == torch.bfloat16:
+                graph.erase_node(activation_to_bf16_node)
+            # Erase the dequant pattern of the activation
+            graph.erase_node(dequant_node)
+            # Erase the dequant pattern of the weight
+            graph.erase_node(t_node)
+            if dtype == torch.bfloat16:
+                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
+            graph.erase_node(dequant_per_tensor)
+
+            counters["inductor"]["scaled_mm_matcher_count"] += 1
+            counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes)
+
+
+def _register_scaled_mm():
+    fp8_linear_weight_prepack_cases = itertools.product(
+        [torch.float32, torch.bfloat16], [False, True]
+    )
+    for dtype, input_dim_exceeds_two in fp8_linear_weight_prepack_cases:
+        patterns = _generate_dequant_fp8_linear_node_pattern(
+            dtype, input_dim_exceeds_two
+        )
+        for pattern in patterns:
+            _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two)
+
+
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -2763,6 +3022,8 @@ def _register_quantization_weight_pack_pass():
     _register_qlinear_unary_fusion()
     _register_qlinear_binary_fusion()
 
+    _register_scaled_mm()
+
 
 def quant_lift_up(module_graph: torch.fx.graph.Graph):
     """
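
Note on the rewrite above (not part of the patch): the pass replaces the FP8 dequantize -> addmm/mm pattern with a single torch.ops.aten._scaled_mm.default call that carries the activation and weight scales, inserting an aten.permute on the FP8 weight first. The short sketch below illustrates that equivalence outside the pattern matcher. It is illustrative only: shapes and scales are arbitrary, and it assumes a device/build where torch._scaled_mm supports float8_e4m3fn inputs (for example a recent oneDNN-enabled CPU build or an SM89+ GPU).

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    M, K, N = 16, 32, 16
    x = torch.randn(M, K, device=device)
    w = torch.randn(N, K, device=device)
    x_scale = torch.tensor(2.0, device=device)
    w_scale = torch.tensor(2.0, device=device)

    # FP8 quantize step, mirroring FP8QDQLinear in the new test.
    qx = (x / x_scale).to(torch.float8_e4m3fn)
    qw = (w / w_scale).to(torch.float8_e4m3fn)

    # Unfused reference: dequantize both operands, then run a regular linear.
    ref = torch.nn.functional.linear(qx.float() * x_scale, qw.float() * w_scale)

    # Fused form the pass emits: permute the FP8 weight, then one scaled matmul.
    out = torch._scaled_mm(
        qx, qw.t(), scale_a=x_scale, scale_b=w_scale, out_dtype=torch.float32
    )

    # Both paths accumulate the same FP8-derived values, so the gap should be small.
    print((ref - out).abs().max())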