diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py
index 130eda03f88..c07d27e4231 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -274,7 +274,7 @@ class XNNPACKQuantizer(Quantizer):
         QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
         QuantPattern("linear", True, False, LINEAR_TARGETS),
         QuantPattern("conv", True, False, CONV_TARGETS),
-        QuantPattern("conv_transpose", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", True, False, CONV_TARGETS),
         QuantPattern("conv_relu", False, False, CONV_TARGETS),
         QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
         QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
index 0dcfb4484ed..3d687d0b513 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -4,7 +4,10 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
+from executorch.backends.xnnpack.utils.utils import (
+    get_groups_from_conv,
+    is_depthwise_conv,
+)
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
+def change_quantization_config(
+    original_qspec,
+    dtype=None,
+    quant_min=None,
+    quant_max=None,
+    qscheme=None,
+    ch_axis=None,
+    is_dynamic=None,
+    observer_or_fake_quant_ctr=None,
+):
+    return QuantizationSpec(
+        dtype=dtype if dtype is not None else original_qspec.dtype,
+        quant_min=quant_min if quant_min is not None else original_qspec.quant_min,
+        quant_max=quant_max if quant_max is not None else original_qspec.quant_max,
+        qscheme=qscheme if qscheme is not None else original_qspec.qscheme,
+        ch_axis=ch_axis if ch_axis is not None else original_qspec.ch_axis,
+        is_dynamic=is_dynamic if is_dynamic is not None else original_qspec.is_dynamic,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
+        or original_qspec.observer_or_fake_quant_ctr,
+    )
+
+
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
@@ -231,6 +256,9 @@ def _do_annotate_conv(
         if is_relu_node(user):
             continue
 
+        # Tracks whether any condition requires skipping this conv
+        skip = False
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
@@ -238,24 +266,34 @@ def _do_annotate_conv(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        num_groups = get_groups_from_conv(conv_node)
 
-        # Only annotate dynamically quantized conv if it's 2D and not depthwise
-        if (
+        # Skip if a transposed conv has more than one group
+        skip = skip or (is_conv_transpose and num_groups != 1)
+
+        if is_conv_transpose:
+            # Transposed convs use per-output-channel quantization (dim 1 of the weight)
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+
+        input_qspec_map[weight] = weight_qspec
+        is_dynamic = (
             quantization_config
             and quantization_config.input_activation
            and quantization_config.input_activation.is_dynamic
-        ):
+        )
+
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if is_dynamic:
             weight_val = weight.meta.get("val", None)
             weight_shape = getattr(weight_val, "shape", None)
-            # Skip if not a 4D weight tensor (i.e. not conv2d)
-            if weight_shape is not None and len(weight_shape) != 4:
-                continue
-
+            # Skip if not a 4D weight tensor (i.e. not conv2d)
+            skip = skip or (weight_shape is not None and len(weight_shape) != 4)
             # Skip if depthwise (default to groups=1 since it's not an arg)
-            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
-                continue
+            skip = skip or (
+                not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
+            )
 
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
 
         bias = conv_node.args[2] if len(conv_node.args) > 2 else None
@@ -265,7 +303,7 @@ def _do_annotate_conv(
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)
 
-        if _is_annotated(partition):
+        if _is_annotated(partition) or skip:
             continue
 
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
@@ -311,7 +349,12 @@ def _do_annotate_conv_relu(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
+        if is_conv_transpose:
+            # Transposed convs use per-output-channel quantization (dim 1 of the weight)
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+        input_qspec_map[weight] = weight_qspec
 
         # adding weight node to the partition as well
         partition = [relu_node, conv_node, conv_node.args[1]]
@@ -323,6 +366,9 @@ def _do_annotate_conv_relu(
         if _is_annotated(partition):
             continue
 
+        if is_conv_transpose and groups != 1:
+            continue
+
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
 
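Why ch_axis=1 for the transposed-conv weight qspec above: a minimal standalone sketch (plain PyTorch shape conventions, not code from this diff) showing that ConvTranspose2d stores output channels on dim 1 of its weight, while Conv2d stores them on dim 0. Per-channel weight quantization wants one scale per output channel, so the channel axis must follow each layout:

import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
tconv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)

print(conv.weight.shape)   # torch.Size([8, 3, 3, 3]) -> C_out on dim 0
print(tconv.weight.shape)  # torch.Size([3, 8, 3, 3]) -> C_out on dim 1
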
diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py
index 92bb03c907a..2a0a82d99b6 100644
--- a/backends/xnnpack/test/ops/test_conv2d.py
+++ b/backends/xnnpack/test/ops/test_conv2d.py
@@ -174,14 +174,11 @@ def get_inputs(self):
 
 
 class Conv2dDQSeq(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, transpose=False):
         super().__init__()
-        self.first = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
-        self.second = torch.nn.Conv2d(
-            in_channels=8, out_channels=10, kernel_size=3, padding=1
-        )
+        op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d
+        self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1)
+        self.second = op(in_channels=8, out_channels=10, kernel_size=3, padding=1)
 
     def forward(self, x):
         y = self.first(x)
@@ -192,14 +189,11 @@ def get_inputs(self):
 
 
 class Conv2dDQParallel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, transpose=False):
         super().__init__()
-        self.first = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
-        self.second = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
+        op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d
+        self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1)
+        self.second = op(in_channels=3, out_channels=10, kernel_size=3, padding=1)
 
     def forward(self, x):
         first = self.first(x)
@@ -221,7 +215,6 @@ def _test(
         conv_count=1,
         dtype: torch.dtype = torch.float,
         check_quantized=True,
-        delegated=True,
     ):
         # pyre-fixme[29]: `Union[torch._tensor.Tensor,
         #  torch.nn.modules.module.Module]` is not a function.
@@ -240,29 +233,20 @@ def _test(
 
         (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())
 
-        if delegated:
-            (
-                tester.check_not(
-                    ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
-                )
-                .check_not(
-                    [
-                        "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
-                    ]
-                )
-                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-                .to_executorch()
-                .serialize()
-                .run_method_and_compare_outputs(qtol=1)
+        (
+            tester.check_not(
+                ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
             )
-        else:
-            # need quantize ops when ops are not delegated to xnnpack
-            if has_quantized_ops:
-                (
-                    tester.to_executorch()
-                    .serialize()
-                    .run_method_and_compare_outputs(qtol=1)
-                )
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
+                ]
+            )
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs(qtol=1)
+        )
 
     def _test_dq(
         self,
@@ -276,8 +260,7 @@ def _test_dq(
         )
 
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
-            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
-            per_op_mode=True,
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
         )
 
         tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
@@ -325,7 +308,6 @@ def test_qs8_conv2d_per_channel(self) -> None:
         self._test(
             Conv2d(transpose=transpose),
             quant_config=get_symmetric_quantization_config(is_per_channel=True),
-            delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
         )
 
     def test_fp32_conv2d_seq(self) -> None:
@@ -360,11 +342,10 @@ def test_fp32_conv2d_depthwise(self):
         )
 
     def test_qs8_conv2d_depthwise(self):
-        for transpose in (True, False):
-            self._test(
-                Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose),
-                quant_config=get_symmetric_quantization_config(),
-            )
+        self._test(
+            Conv2d(groups=2, in_channels=2, out_channels=6),
+            quant_config=get_symmetric_quantization_config(),
+        )
 
     def test_fp32_conv2d_bn(self):
         class Conv2dBatchNorm(torch.nn.Module):
@@ -485,7 +466,6 @@ def get_inputs(self):
         self._test(
             ConvReLU(transpose=transpose),
             quant_config=get_symmetric_quantization_config(is_per_channel=True),
-            delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
         )
 
     def test_qs8_conv2d_dw_relu(self):
@@ -527,19 +507,14 @@ def forward(self, x):
         def get_inputs(self):
             return (torch.randn(batches, in_channels, height, width) * 11,)
 
-        for transpose in (True, False):
-            for per_channel_quant in (False, True):
-                if transpose and per_channel_quant:
-                    continue
-                model = ModelConvReLU(transpose=transpose)
-                self._test(
-                    model,
-                    quant_config=get_symmetric_quantization_config(
-                        is_per_channel=per_channel_quant
-                    ),
-                    # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
-                    delegated=not (transpose and per_channel_quant),
-                )
+        for per_channel_quant in (False, True):
+            model = ModelConvReLU()
+            self._test(
+                model,
+                quant_config=get_symmetric_quantization_config(
+                    is_per_channel=per_channel_quant
+                ),
+            )
 
     def test_qs8_conv2d_relu_seq(self):
         class ConvReLUSeq(torch.nn.Module):
@@ -593,7 +568,7 @@ def get_inputs(self):
             conv_count=2,
         )
 
-    def test_qs8_conv_transpose_2d_quantize_per_channel(self):
+    def test_qs8_conv_transpose_2d_quantize_per_channel_multi_axis(self):
         class PerChannelConvTranspose2d(torch.nn.Module):
             def __init__(self, input_channels, output_channels, groups, axis):
                 super().__init__()
@@ -662,76 +637,24 @@ def get_inputs(self):
                 )
 
         for groups in (1, 2):
-            for axis in (0, 1):
-                self._test(
-                    PerChannelConvTranspose2d(3 * groups, 5 * groups, groups, axis),
-                    quant_config=None,
-                    conv_count=1,
-                    delegated=axis == 1
-                    and groups
-                    == 1,  # xnnpack only support output channel axis quantization with groups == 1
-                )
-
-    def test_qs8_conv_transpose_2d_dqd_f32_weights(self):
-        class TransposeConv2dDQDf32weights(torch.nn.Module):
-            def __init__(self, input_channels, output_channels, groups, axis):
-                super().__init__()
-                self.input_channels = input_channels
-                self.output_channels = output_channels
-                self.axis = axis
-                self.groups = groups
-                self.transpose = True
-                self.weights = torch.nn.Parameter(
-                    torch.randn((input_channels, output_channels // groups, 4, 4)),
-                    requires_grad=False,
-                )
-
-                axis_size = self.weights.shape[axis]
-                self.scale = torch.nn.Parameter(torch.ones(axis_size) * 0.12345)
-                self.zero_point = torch.nn.Parameter(
-                    torch.zeros((axis_size,), dtype=torch.int64), requires_grad=False
-                )
-
-            def forward(self, x):
-                dequantize_input = (
-                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
-                        x, 0.12345, 0, -127, 127, torch.int8
-                    )
-                )
-
-                x = torch.nn.functional.conv_transpose2d(
-                    dequantize_input, self.weights, groups=self.groups
-                )
-
-                return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
-                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
-                        x,
-                        0.12345,
-                        0,
-                        -127,
-                        127,
-                        torch.int8,
-                    ),
-                    0.12345,
-                    0,
-                    -127,
-                    127,
-                    torch.int8,
-                )
-
-            def get_inputs(self):
-                return (
-                    torch.randint(
-                        low=-127, high=127, size=(3, self.input_channels, 4, 4)
-                    ).type(dtype=torch.int8),
-                )
-
-        for groups in (1, 2):
-            for axis in (0, 1):
-                self._test(
-                    TransposeConv2dDQDf32weights(3 * groups, 5 * groups, groups, axis),
-                    quant_config=None,
-                    conv_count=1,
-                )
+            for ch_axis in (1, 2):
+                if ch_axis == 1 and groups == 1:
+                    self._test(
+                        PerChannelConvTranspose2d(
+                            3 * groups, 5 * groups, groups, ch_axis
+                        ),
+                        quant_config=None,
+                        conv_count=1,
+                    )
+                else:
+                    with self.assertRaises(RuntimeError):
+                        self._test(
+                            PerChannelConvTranspose2d(
+                                3 * groups, 5 * groups, groups, ch_axis
+                            ),
+                            quant_config=None,
+                            conv_count=1,
+                        )
 
     def test_padded_output_tconv(self):
         class TConv2d(torch.nn.Module):
@@ -761,7 +684,7 @@ def forward(self, x):
 
         (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())
 
-        # tconv should not be offloaded to XNNPack, since output padding is not
+        # tconv should not be offloaded to XNNPACK, since output padding is not supported
         (
             tester.check(
                 ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
@@ -794,3 +717,31 @@ def test_dq_conv2d_parallel(self) -> None:
         model = Conv2dDQParallel()
         conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
         self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_transpose(self) -> None:
+        model = Conv2d(
+            in_channels=3,
+            out_channels=10,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(0, 0),
+            batches=1,
+            width=8,
+            height=8,
+            transpose=True,
+        )
+        self._test_dq(model)
+
+    def test_dq_conv2d_transpose_seq(self) -> None:
+        model = Conv2dDQSeq(transpose=True)
+        conv_count = sum(
+            1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d
+        )
+        self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_transpose_parallel(self) -> None:
+        model = Conv2dDQParallel(transpose=True)
+        conv_count = sum(
+            1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d
+        )
+        self._test_dq(model, conv_count)
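For intuition on the ch_axis values the renamed multi-axis test sweeps: a hypothetical helper (per_channel_scales is illustrative only, not part of this diff) that derives symmetric per-channel scales along a chosen weight axis. Only axis 1 of a transposed-conv weight yields one scale per output channel; other axes are expected to fail lowering, which is what the assertRaises branch exercises:

import torch

def per_channel_scales(weight: torch.Tensor, axis: int, qmax: int = 127) -> torch.Tensor:
    # Reduce over every dimension except `axis`, then derive a symmetric scale.
    dims = [d for d in range(weight.dim()) if d != axis]
    return weight.abs().amax(dim=dims) / qmax

w = torch.randn(3, 8, 4, 4)  # ConvTranspose2d weight: (C_in, C_out/groups, kH, kW)
print(per_channel_scales(w, axis=1).shape)  # torch.Size([8]): one scale per output channel
print(per_channel_scales(w, axis=2).shape)  # torch.Size([4]): a kernel axis, not meaningful
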
diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py
index b23fd444117..a8f3178f98f 100644
--- a/backends/xnnpack/utils/utils.py
+++ b/backends/xnnpack/utils/utils.py
@@ -25,6 +25,7 @@
     is_lifted_tensor_constant,
     is_param,
 )
+from torchao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node
 
 
 ### XNNPACK Capture ###
@@ -160,6 +161,36 @@ def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     return source_fn[1]
 
 
+def get_groups_from_conv(conv_node: torch.fx.Node) -> int:
+    if _is_conv_node(conv_node):
+        in_node = cast(torch.fx.Node, conv_node.args[0])
+        weight_node = cast(torch.fx.Node, conv_node.args[1])
+        # groups isn't available as an arg in the training graph, so we deduce
+        # it from the input and weight shapes
+
+        # input shape is (N, C_in, H_in, W_in)
+        in_channels = in_node.meta["val"].shape[1]
+
+        # weight shape is (C_out, C_in/groups, kernel_size[0], kernel_size[1])
+        in_channels_per_group = weight_node.meta["val"].shape[1]
+
+        return in_channels // in_channels_per_group
+    elif _is_conv_transpose_node(conv_node):
+        weight_node = cast(torch.fx.Node, conv_node.args[1])
+        # groups isn't available as an arg in the training graph, so we deduce
+        # it from the weight and output shapes
+
+        # weight shape is (C_in, C_out/groups, kernel_size[0], kernel_size[1])
+        out_channels_per_group = weight_node.meta["val"].shape[1]
+
+        # output shape is (N, C_out, H_out, W_out)
+        out_channels = conv_node.meta["val"].shape[1]
+
+        return out_channels // out_channels_per_group
+
+    raise RuntimeError(f"expected {conv_node} to be a conv or conv_transpose node")
+
+
 def is_depthwise_conv(
     kernel_shape: Tuple[int, ...], groups: int = 1, is_transpose: bool = False
 ) -> bool:
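As a sanity check on the shape arithmetic in get_groups_from_conv, a small sketch against eager-mode modules (assuming only the standard PyTorch weight layouts noted in the comments above; not part of the diff):

import torch

groups = 2
conv = torch.nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3, groups=groups)
tconv = torch.nn.ConvTranspose2d(in_channels=4, out_channels=6, kernel_size=3, groups=groups)

# Conv2d weight is (C_out, C_in/groups, kH, kW): recover groups from the input side.
assert 4 // conv.weight.shape[1] == groups  # 4 in_channels / 2 per group

# ConvTranspose2d weight is (C_in, C_out/groups, kH, kW): recover groups from the output side.
assert 6 // tconv.weight.shape[1] == groups  # 6 out_channels / 3 per group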