Skip to content

Commit 3884e29

Browse files
authored
Patch the _is_conv_node function
Differential Revision: D74898941 Pull Request resolved: #2223
1 parent adc78b7 commit 3884e29

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,6 +2478,26 @@ def forward(self, x):
24782478
node_list,
24792479
)
24802480

2481+
example_inputs = (torch.randn(1, 3, 5, 5),)
2482+
node_occurrence = {
2483+
# two for input of the first conv, one for output for the first conv
2484+
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
2485+
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
2486+
}
2487+
node_list = [
2488+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2489+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2490+
torch.ops.aten.conv2d.padding,
2491+
torch.ops.aten.relu.default,
2492+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
2493+
]
2494+
self._test_quantizer(
2495+
TestHelperModules.ConvWithBNRelu(dim=2, relu=True, bn=True, padding="same"),
2496+
example_inputs,
2497+
BackendAQuantizer(),
2498+
node_occurrence,
2499+
node_list,
2500+
)
24812501
def test_conv_transpose3d_bn_relu(self):
24822502
class BackendAQuantizer(Quantizer):
24832503
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

torchao/quantization/pt2e/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,8 +625,11 @@ def _is_conv_node(n: Node):
625625
"""
626626
return n.op == "call_function" and n.target in [
627627
torch.ops.aten.conv1d.default,
628+
torch.ops.aten.conv1d.padding,
628629
torch.ops.aten.conv2d.default,
630+
torch.ops.aten.conv2d.padding,
629631
torch.ops.aten.conv3d.default,
632+
torch.ops.aten.conv3d.padding,
630633
]
631634

632635

0 commit comments

Comments
 (0)