Skip to content

[Quantized DeConv Support] Enable Quantized Transposed Convs with groups==1 #11774

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/xnnpack/quantizer/xnnpack_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class XNNPACKQuantizer(Quantizer):
QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
QuantPattern("linear", True, False, LINEAR_TARGETS),
QuantPattern("conv", True, False, CONV_TARGETS),
QuantPattern("conv_transpose", False, False, CONV_TARGETS),
QuantPattern("conv_transpose", True, False, CONV_TARGETS),
QuantPattern("conv_relu", False, False, CONV_TARGETS),
QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
Expand Down
72 changes: 59 additions & 13 deletions backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import torch
import torch.nn.functional as F
from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
from executorch.backends.xnnpack.utils.utils import (
get_groups_from_conv,
is_depthwise_conv,
)
from torch._subclasses import FakeTensor
from torch.fx import Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
Expand Down Expand Up @@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
return decorator


def change_quantization_config(
original_qspec,
dtype=None,
quant_min=None,
quant_max=None,
qscheme=None,
ch_axis=None,
is_dynamic=None,
observer_or_fake_quant_ctr=None,
):
return QuantizationSpec(
dtype=dtype or original_qspec.dtype,
quant_min=quant_min or original_qspec.quant_min,
quant_max=quant_max or original_qspec.quant_max,
qscheme=qscheme or original_qspec.qscheme,
ch_axis=ch_axis or original_qspec.ch_axis,
is_dynamic=is_dynamic or original_qspec.is_dynamic,
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
or original_qspec.observer_or_fake_quant_ctr,
)


def is_relu_node(node: Node) -> bool:
"""
Check if a given node is a relu node
Expand Down Expand Up @@ -231,31 +256,44 @@ def _do_annotate_conv(
if is_relu_node(user):
continue

# Tracks conditions for whether or not to skip
skip = False

input_qspec_map = {}
input_act = conv_node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
weight_qspec = get_weight_qspec(quantization_config)
num_groups = get_groups_from_conv(conv_node)

# Only annotate dynamically quantized conv if it's 2D and not depthwise
if (
# skip if transposed conv has more than 1 group
skip = skip or (is_conv_transpose and num_groups != 1)
print(f"{skip} conv transpose and num_groups")

if is_conv_transpose:
# transposed convs per output channel quantization
weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)

input_qspec_map[weight] = weight_qspec
is_dynamic = (
quantization_config
and quantization_config.input_activation
and quantization_config.input_activation.is_dynamic
):
)

# Only annotate dynamically quantized conv if it's 2D and not depthwise
if is_dynamic:
weight_val = weight.meta.get("val", None)
weight_shape = getattr(weight_val, "shape", None)

# Skip if not a 4D weight tensor (i.e. not conv2d)
if weight_shape is not None and len(weight_shape) != 4:
continue

skip = skip or (weight_shape is not None and len(weight_shape) != 4)
# Skip if depthwise (default to groups=1 since it's not an arg)
if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
continue
skip = skip or (
not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
)

# adding weight node to the partition as well
partition = [conv_node, conv_node.args[1]]
Expand All @@ -265,7 +303,7 @@ def _do_annotate_conv(
input_qspec_map[bias] = get_bias_qspec(quantization_config)
partition.append(bias)

if _is_annotated(partition):
if _is_annotated(partition) or skip:
continue

if filter_fn and any(not filter_fn(n) for n in partition):
Expand Down Expand Up @@ -311,7 +349,12 @@ def _do_annotate_conv_relu(

weight = conv_node.args[1]
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)
weight_qspec = get_weight_qspec(quantization_config)
groups = get_groups_from_conv(conv_node)
if is_conv_transpose:
# transposed convs per output channel quantization
weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
input_qspec_map[weight] = weight_qspec

# adding weight node to the partition as well
partition = [relu_node, conv_node, conv_node.args[1]]
Expand All @@ -323,6 +366,9 @@ def _do_annotate_conv_relu(
if _is_annotated(partition):
continue

if is_conv_transpose and groups != 1:
continue

if filter_fn and any(not filter_fn(n) for n in partition):
continue

Expand Down
Loading
Loading