[Inductor] Support scaled mm on inductor #2411
@@ -2740,6 +2740,241 @@ def _register_qlinear_binary_fusion():
    )


def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two):
    # + - - - - | - - - - - - | - - - - +
    # |  dq_per_tensor    dq_per_tensor |
    # |        |                |       |
    # |   OPT(to_bf16)    OPT(to_bf16)  |
    # |        |                |       |
    # |   OPT(reshape)       permute    |
    # |          \           /          |
    # |            addmm/mm             |
    # |                |                |
    # |      OPT(quant_per_tensor)      |
    # |                |                |
    # |          OPT(reshape)           |
    assert dtype in [torch.float32, torch.bfloat16]
    dequant_wgt_pattern = CallFunction(
        torch.ops.torchao.dequantize_affine_float8.default,
        KeywordArg("q_weight"),
        KeywordArg("w_scale"),
        output_dtype=KeywordArg("w_dtype"),
    )
    t_pattern = CallFunction(
        aten.permute.default,
        _may_generate_pattern_with_dtype_convert(
            dequant_wgt_pattern,
            KeywordArg("autocast_wgt_dtype"),
            dtype == torch.bfloat16,
        ),
        KeywordArg("permute_axes"),
    )
    dequantize_per_tensor_activation_pattern = CallFunction(
        torch.ops.torchao.dequantize_affine_float8.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        output_dtype=KeywordArg("x_dq_dtype"),
    )

    dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape(
        CallFunction(
            aten.addmm.default,
            KeywordArg("b"),
            _may_generate_pattern_with_reshape(
                _may_generate_pattern_with_dtype_convert(
                    dequantize_per_tensor_activation_pattern,
                    KeywordArg("autocast_act_dtype"),
                    dtype == torch.bfloat16,
                ),
                KeywordArg("act_reshape_size"),
                input_dim_exceeds_two,
            ),
            t_pattern,
        ),
        KeywordArg("output_reshape_size"),
        input_dim_exceeds_two,
    )
    dequant_fp8_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
        CallFunction(
            aten.mm.default,
            _may_generate_pattern_with_reshape(
                _may_generate_pattern_with_dtype_convert(
                    dequantize_per_tensor_activation_pattern,
                    KeywordArg("autocast_act_dtype"),
                    dtype == torch.bfloat16,
                ),
                KeywordArg("act_reshape_size"),
                input_dim_exceeds_two,
            ),
            t_pattern,
        ),
        KeywordArg("output_reshape_size"),
        input_dim_exceeds_two,
    )
    return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern

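For readers unfamiliar with the pattern DSL above, the matched subgraph corresponds roughly to the eager computation sketched below. This is an illustrative reference only: the helper name `fp8_linear_reference` and the plain per-tensor scaling are assumptions, and the real graph contains `torch.ops.torchao.dequantize_affine_float8` nodes rather than manual casts.

```python
import torch

def fp8_linear_reference(qx, x_scale, qw, w_scale, bias=None, out_dtype=torch.bfloat16):
    # dq_per_tensor on activation and weight (per-tensor scales assumed)
    x = qx.to(torch.float32) * x_scale
    w = qw.to(torch.float32) * w_scale
    # OPT(to_bf16): optional cast when the model runs in bf16
    x = x.to(out_dtype)
    w = w.to(out_dtype)
    # OPT(reshape) + permute + addmm/mm, then OPT(reshape) back
    x_2d = x.reshape(-1, x.shape[-1])
    w_t = w.permute(1, 0)
    if bias is not None:
        out = torch.addmm(bias.to(out_dtype), x_2d, w_t)
    else:
        out = torch.mm(x_2d, w_t)
    return out.reshape(*x.shape[:-1], out.shape[-1])
```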
def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two):

[Review comment] The pattern is fp8 qlinear, not scaled_mm.

    def _inner(match):
        input_contiguous = True
        # Check dequant pattern has only 1 user.
        (
            linear_node,
            _,
        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)

        input_index = 1 if linear_node.target is aten.addmm.default else 0
        assert dtype in [torch.float32, torch.bfloat16]
        (
            dequant_node,
            _,
            _,
            _,
        ) = _get_linear_dq_node(
            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
        )
        assert dequant_node.target is torch.ops.torchao.dequantize_affine_float8.default

        # only support float8_e4m3 input
        if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn:
            return False

        if len(list(dequant_node.users)) != 1:
            # Ensure the dequant pattern only has 1 user
            # since we will delete the dequant pattern here
            return False

        return True

    return _inner

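As background for the single-user guard above, `torch.fx` exposes a node's consumers via `node.users`; the standalone sketch below (unrelated to fp8, purely illustrative) shows how that count behaves.

```python
import torch
from torch import fx

def f(x):
    y = x + 1          # this add feeds two different consumers below
    return y * 2, y - 3

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.name, "->", len(node.users), "user(s)")
# The add node reports 2 users, so a guard like `len(node.users) != 1`
# would reject it, mirroring the check in `_inner`.
```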
def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two):

[Review comment] Same here: scaled_mm -> fp8_qlinear.

    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two),
        pass_number=0,
    )
    def scaled_mm_fusion(match: Match, *args, **kwargs):
        input_contiguous = True
        assert dtype in [torch.float32, torch.bfloat16]
        (
            linear_node,
            output_reshape_node,
        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
        input_index = 1 if linear_node.target is aten.addmm.default else 0
        weight_index = input_index + 1

        (
            dequant_node,
            act_reshape_node,
            activation_to_bf16_node,
            act_expand_node,
        ) = _get_linear_dq_node(
            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
        )

        if input_dim_exceeds_two and not input_contiguous:
            wgt_expand_node = linear_node.args[weight_index]
            assert wgt_expand_node.target is aten.expand.default
            t_node = wgt_expand_node.args[0]
        else:
            t_node = linear_node.args[weight_index]

        if dtype == torch.float32:
            dequant_per_tensor = t_node.args[0]
        else:
            weight_to_bf16_node = t_node.args[0]
            dequant_per_tensor = weight_to_bf16_node.args[0]
        assert (
            dequant_per_tensor.target
            is torch.ops.torchao.dequantize_affine_float8.default
        )

        # Activation QParams
        qx, x_scale = (
            kwargs["x"],
            kwargs["x_scale"],
        )

        # Weight QParams
        qw, w_scale = (
            kwargs["q_weight"],
            kwargs["w_scale"],
        )

        # Params
        bias = kwargs["b"] if "b" in kwargs else None

        x_shape = qx.meta.get("tensor_meta").shape
        if has_free_symbols(x_shape):
            # For dynamic shape case, we can't get activation shape ahead of runtime.
            x_shape = None
        graph = match.graph
        with graph.inserting_before(linear_node):
            scaled_mm_input_node = qx
            if input_dim_exceeds_two:
                new_reshape_args: tuple[Any, ...] = (qx, act_reshape_node.args[1])
                new_act_reshape_node = graph.call_function(
                    torch.ops.aten.reshape.default, args=new_reshape_args
                )
                scaled_mm_input_node = new_act_reshape_node
            # Insert weight prepack node and the qlinear node
            permute_weight_inputs = (
                qw,
                t_node.args[1],
            )
            permute_weight_op = torch.ops.aten.permute.default
            permute_weight_node = graph.call_function(
                permute_weight_op, args=permute_weight_inputs
            )
            output_scale = torch.tensor(1.0)
            new_args: tuple[Any, ...] = (
                scaled_mm_input_node,
                permute_weight_node,
                x_scale,
                w_scale,
                bias,
                output_scale,  # output_scale
                dtype,  # output_dtype
                False,  # use_fast_accum
            )
            new_linear_node = graph.call_function(
                torch.ops.aten._scaled_mm.default, args=new_args
            )

            linear_node.replace_all_uses_with(new_linear_node)
            new_linear_node.meta.update(linear_node.meta)

            graph.erase_node(linear_node)
            if input_dim_exceeds_two:
                graph.erase_node(act_reshape_node)
            if dtype == torch.bfloat16:
                graph.erase_node(activation_to_bf16_node)
            # Erase the dequant pattern
            graph.erase_node(dequant_node)
            # Erase the dequant per channel pattern
            graph.erase_node(t_node)
            if dtype == torch.bfloat16:
                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
            graph.erase_node(dequant_per_tensor)

            counters["inductor"]["scaled_mm_matcher_count"] += 1
            counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes)

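In effect, the rewritten subgraph performs a single `torch._scaled_mm` call on the fp8 activation and the permuted fp8 weight. The standalone sketch below shows the equivalent eager call; the shapes, the per-tensor scale computation, and the bf16 output dtype are assumptions for illustration, and whether it runs depends on the fp8 support of your PyTorch build and hardware.

```python
import torch

M, K, N = 16, 64, 32
x = torch.randn(M, K)
w = torch.randn(N, K)

# Per-tensor fp8 quantization (illustrative scheme)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = x.abs().max().float() / fp8_max
w_scale = w.abs().max().float() / fp8_max
qx = (x / x_scale).to(torch.float8_e4m3fn)
qw = (w / w_scale).to(torch.float8_e4m3fn)

# Row-major fp8 activation times column-major (permuted) fp8 weight,
# matching the args tuple built by scaled_mm_fusion above.
out = torch._scaled_mm(
    qx,
    qw.t(),
    scale_a=x_scale,
    scale_b=w_scale,
    bias=None,
    out_dtype=torch.bfloat16,
    use_fast_accum=False,
)
```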
def _register_scaled_mm():

[Review comment] Same here: scaled_mm -> fp8_qlinear.

    fp8_linear_weight_prepack_cases = itertools.product(
        [torch.float32, torch.bfloat16], [False, True]
    )
    for dtype, input_dim_exceeds_two in fp8_linear_weight_prepack_cases:
        patterns = _generate_dequant_fp8_linear_node_pattern(
            dtype, input_dim_exceeds_two
        )
        for pattern in patterns:
            _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two)

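For reference, the loop above enumerates four (dtype, input_dim_exceeds_two) combinations, and each combination yields a bias and a no-bias pattern, i.e. eight registered patterns in total. A tiny sketch of that enumeration (illustration only):

```python
import itertools

import torch

for dtype, input_dim_exceeds_two in itertools.product(
    [torch.float32, torch.bfloat16], [False, True]
):
    # Each case registers two patterns: addmm (with bias) and mm (no bias).
    print(dtype, input_dim_exceeds_two)
```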
@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
    # Step 1: Dequant promotion for int8-mixed-fp32/bf16

@@ -2763,6 +2998,8 @@ def _register_quantization_weight_pack_pass():
    _register_qlinear_unary_fusion()
    _register_qlinear_binary_fusion()

    _register_scaled_mm()


def quant_lift_up(module_graph: torch.fx.graph.Graph):
    """

[Review comment] I believe this is the training float8 test file; float8 inference uses https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_float.py

[Author reply] OK. I changed the unit test path in the last PR, #2379.