
Commit d863085

added the testcases
1 parent 2766765 commit d863085

1 file changed: 273 additions, 0 deletions

@@ -0,0 +1,273 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# BSD-3-Clause

# Owner(s): ["oncall: quantization"]

import functools
import platform
import unittest
from typing import Dict

import torch
import torch.nn as nn
import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as armiq
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import (
    ArmInductorQuantizer,
)
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torchao.quantization.pt2e.inductor_passes.arm import (
    _register_quantization_weight_pack_pass,
)

from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    skipIfNoInductorSupport,
)
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7

# ----------------------------------------------------------------------------- #
# Helper decorators                                                              #
# ----------------------------------------------------------------------------- #
def skipIfNoArm(fn):
    """Skip the decorated test class or function on non-aarch64 hosts."""
    reason = "Quantized operations require Arm."
    if isinstance(fn, type):
        # Decorating a class: mark the whole TestCase as skipped.
        if platform.processor() != "aarch64":
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    # Decorating a function: skip at call time.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if platform.processor() != "aarch64":
            raise unittest.SkipTest(reason)
        return fn(*args, **kwargs)

    return wrapper

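# Illustrative usage only (the names below are hypothetical, not from this
# file): skipIfNoArm can wrap either a TestCase subclass or a single test.
#
#   @skipIfNoArm
#   class ArmOnlyTests(unittest.TestCase):
#       ...
#
#   @skipIfNoArm
#   def test_arm_only(self):
#       ...
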
# ----------------------------------------------------------------------------- #
# Mini-models                                                                     #
# ----------------------------------------------------------------------------- #
class _SingleConv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(x)


class _SingleLinear(nn.Module):
    def __init__(self, bias: bool = False):
        super().__init__()
        self.linear = nn.Linear(16, 16, bias=bias)

    def forward(self, x):
        return self.linear(x)

if TORCH_VERSION_AT_LEAST_2_5:
    # torch.export.export_for_training is only available from torch 2.5 onwards.
    from torch.export import export_for_training

# ----------------------------------------------------------------------------- #
# Base harness                                                                    #
# ----------------------------------------------------------------------------- #
class _ArmInductorPerTensorTestCase(QuantizationTestCase):
    def _test_quantizer(
        self,
        model: torch.nn.Module,
        example_inputs: tuple[torch.Tensor, ...],
        quantizer: ArmInductorQuantizer,
        expected_node_occurrence: Dict[torch._ops.OpOverload, int],
        expected_node_list=None,
        *,
        is_qat: bool = False,
        lower: bool = False,
    ):
        # Export to training IR, then annotate with the quantizer. QAT flows
        # through prepare_qat_pt2e (assumed to be exported alongside
        # prepare_pt2e) so fake-quant is inserted; PTQ uses observers.
        m = model.train() if is_qat else model.eval()
        gm = export_for_training(m, example_inputs).module()

        gm = prepare_qat_pt2e(gm, quantizer) if is_qat else prepare_pt2e(gm, quantizer)
        gm(*example_inputs)  # single calibration / warm-up pass
        gm = convert_pt2e(gm)

        if lower:
            # Register weight-pack pass (only affects per-tensor path; harmless otherwise)
            _register_quantization_weight_pack_pass(per_channel=False)
            from torch._inductor.constant_folding import constant_fold
            from torch._inductor.fx_passes.freezing_patterns import freezing_passes

            gm.recompile()
            freezing_passes(gm, example_inputs)
            constant_fold(gm)
            gm(*example_inputs)

        self.checkGraphModuleNodes(
            gm,
            expected_node_occurrence={
                ns.call_function(k): v for k, v in expected_node_occurrence.items()
            },
            expected_node_list=[
                ns.call_function(n) for n in (expected_node_list or [])
            ],
        )

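# How to read the expected-op counts in the tests below: in static per-tensor
# mode the activation is quantized once and dequantized before the float op,
# and the per-tensor weight adds a second dequantize_per_tensor (1 quantize /
# 2 dequantize / 0 per-channel); with per-channel weights, the weight
# dequantize shows up as dequantize_per_channel instead (1 / 1 / 1).
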
# ----------------------------------------------------------------------------- #
# Test-suite                                                                      #
# ----------------------------------------------------------------------------- #
@skipIfNoInductorSupport
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class TestQuantizePT2EArmInductorPerTensor(_ArmInductorPerTensorTestCase):
    # ------------------------------------------------------------------ #
    # 1. Conv2d - per-tensor static PTQ                                   #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_conv2d_per_tensor_weight(self):
        example_inputs = (torch.randn(2, 3, 16, 16),)
        q = ArmInductorQuantizer().set_global(
            armiq.get_default_arm_inductor_quantization_config(is_per_channel=False)
        )
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(
            _SingleConv2d(), example_inputs, q, expected, lower=True
        )

    # ------------------------------------------------------------------ #
    # 2. Linear - per-tensor static PTQ                                   #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_linear_per_tensor_weight(self):
        example_inputs = (torch.randn(4, 16),)
        q = ArmInductorQuantizer().set_global(
            armiq.get_default_arm_inductor_quantization_config(is_per_channel=False)
        )
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(
            _SingleLinear(), example_inputs, q, expected, lower=True
        )

    # ------------------------------------------------------------------ #
    # 3. Linear - per-tensor **dynamic**                                  #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_linear_dynamic_per_tensor_weight(self):
        example_inputs = (torch.randn(8, 16),)
        q = ArmInductorQuantizer().set_global(
            armiq.get_default_arm_inductor_quantization_config(
                is_dynamic=True, is_per_channel=False
            )
        )
        # Dynamic path: choose_qparams computes the activation scale/zero-point
        # at runtime, so the activation q/dq use the .tensor overloads; the
        # remaining .default dequantize belongs to the static weight.
        expected = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(
            _SingleLinear(), example_inputs, q, expected, lower=True
        )

    # ------------------------------------------------------------------ #
    # 4. Conv2d - **per-channel** static PTQ                              #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_conv2d_per_channel_weight(self):
        example_inputs = (torch.randn(2, 3, 16, 16),)
        q = ArmInductorQuantizer().set_global(
            armiq.get_default_arm_inductor_quantization_config(is_per_channel=True)
        )
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        self._test_quantizer(
            _SingleConv2d(), example_inputs, q, expected, lower=True
        )

    # ------------------------------------------------------------------ #
    # 5. Linear - **per-channel** static PTQ                              #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_linear_per_channel_weight(self):
        example_inputs = (torch.randn(4, 16),)
        q = ArmInductorQuantizer().set_global(
            armiq.get_default_arm_inductor_quantization_config(is_per_channel=True)
        )
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        self._test_quantizer(
            _SingleLinear(), example_inputs, q, expected, lower=True
        )

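    # The two QAT tests below (is_qat=True) expect the same op counts as their
    # PTQ counterparts: after convert_pt2e the fake-quant graph is rewritten
    # into the same quantize/dequantize pattern, so test 6 mirrors test 1 and
    # test 7 mirrors test 3.
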
    # ------------------------------------------------------------------ #
    # 6. Conv2d - **QAT** per-tensor                                      #
    # ------------------------------------------------------------------ #
    @skipIfTorchDynamo("slow under Dynamo")
    @skipIfNoArm
    def test_conv2d_qat_per_tensor_weight(self):
        example_inputs = (torch.randn(2, 3, 16, 16),)
        q = ArmInductorQuantizer().set_global(
            armiq.get_default_arm_inductor_quantization_config(is_qat=True)
        )
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(
            _SingleConv2d(),
            example_inputs,
            q,
            expected,
            is_qat=True,
            lower=True,
        )

    # ------------------------------------------------------------------ #
    # 7. Linear - **dynamic + QAT** per-tensor                            #
    # ------------------------------------------------------------------ #
    @skipIfTorchDynamo("slow under Dynamo")
    @skipIfNoArm
    def test_linear_dynamic_qat_per_tensor_weight(self):
        example_inputs = (torch.randn(8, 16),)
        q = ArmInductorQuantizer().set_global(
            armiq.get_default_arm_inductor_quantization_config(
                is_dynamic=True, is_qat=True, is_per_channel=False
            )
        )
        expected = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(
            _SingleLinear(),
            example_inputs,
            q,
            expected,
            is_qat=True,
            lower=True,
        )

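# To run this file directly (hypothetical filename; run_tests accepts
# unittest-style flags such as -k for filtering):
#   python test_arm_inductor_quantizer.py -k per_tensor
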
if __name__ == "__main__":
    run_tests()
