Commit 280db15

Fix quantizer tests after dq conv is enabled

Differential Revision: D73898719
Pull Request resolved: #10569

1 parent c47f8e4 · commit 280db15

File tree: 3 files changed (+84, -85 lines)

backends/xnnpack/quantizer/xnnpack_quantizer.py (+71, -80)
@@ -3,7 +3,8 @@

 import copy
 import functools
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Set, TYPE_CHECKING

 import torch
 import torch._dynamo as torchdynamo
@@ -235,37 +236,52 @@ def not_module_type_or_name_filter(n: Node) -> bool:
     return not_module_type_or_name_filter


-class XNNPACKQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-    STATIC_QAT_ONLY_OPS = [
-        "conv_bn_relu",
-        "conv_bn",
-        "conv_transpose_bn_relu",
-        "conv_transpose_bn",
-    ]
+@dataclass
+class QuantPattern:
+    name: str
+    is_dynamic: bool
+    is_qat: bool
+    op_overloads: Set[torch._ops.OpOverloadPacket]
+
+
+CONV_TARGETS = {
+    torch.ops.aten.conv2d.default,
+    torch.ops.aten.conv1d.default,
+    torch.ops.aten.convolution.default,
+}
+
+LINEAR_TARGETS = {
+    torch.ops.aten.linear.default,
+}
+
+ADAPTIVE_AVG_POOL2D_TARGETS = {torch.ops.aten.adaptive_avg_pool2d.default}
+
+ADD_TARGETS = {torch.ops.aten.add.Tensor}
+
+MUL_TARGETS = {torch.ops.aten.mul.Tensor}
+
+CAT_TARGETS = {torch.ops.aten.cat.default}

-    # static quantization ops (both PTQ and QAT)
-    # Preserve the order that fusions come before singular ops
-    STATIC_OPS = [
-        "linear_relu",
-        "linear",
-        "conv",
-        "conv_transpose",
-        "conv_relu",
-        "conv_transpose_relu",
-        "adaptive_avg_pool2d",
-        # TODO: move this to BoltNNQuantizer?
-        "gru_io_only",
-        "add_relu",
-        "add",
-        "mul_relu",
-        "mul",
-        "cat",
-    ]

-    DYNAMIC_OPS = [
-        "linear",
-        "conv",
+class XNNPACKQuantizer(Quantizer):
+    supported_config_and_operators = _get_supported_config_and_operators()
+    SUPPORTED_PATTERNS = [
+        QuantPattern("conv_bn_relu", False, True, CONV_TARGETS),
+        QuantPattern("conv_bn", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS),
+        QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
+        QuantPattern("linear", True, False, LINEAR_TARGETS),
+        QuantPattern("conv", True, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", False, False, CONV_TARGETS),
+        QuantPattern("conv_relu", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
+        QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
+        QuantPattern("add_relu", False, False, ADD_TARGETS),
+        QuantPattern("add", False, False, ADD_TARGETS),
+        QuantPattern("mul_relu", False, False, MUL_TARGETS),
+        QuantPattern("mul", False, False, MUL_TARGETS),
+        QuantPattern("cat", False, False, CAT_TARGETS),
     ]

     def __init__(self) -> None:
@@ -347,83 +363,58 @@ def transform_for_annotation(

     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         """just handling global spec for now"""
-        # hacked for handling dynamic linear quant. will fix later.
-        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
-            model = self._annotate_for_dynamic_quantization_config(model)
-        else:
-            model = self._annotate_for_static_quantization_config(model)
+        model = self._annotate_for_quantization_config(model)
         propagate_annotation(model)
         return model

-    def _annotate_all_static_patterns(
+    def _annotate_all_patterns(
         self,
         model: torch.fx.GraphModule,
         quantization_config: Optional[QuantizationConfig],
         filter_fn: Optional[Callable[[Node], bool]] = None,
-    ) -> torch.fx.GraphModule:
+        operator_target: Optional[torch._ops.OpOverloadPacket] = None,
+    ):
         # TODO: implement the support for None to be canceling out previous annotations
         if quantization_config is None:
             return model

-        if quantization_config.is_qat:
-            for op in self.STATIC_QAT_ONLY_OPS:
-                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
-        for op in self.STATIC_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
-        return model
+        for pattern in self.SUPPORTED_PATTERNS:
+            if operator_target and operator_target not in pattern.op_overloads:
+                # if operator_target is specified, skip patterns that aren't
+                # associated with that target
+                continue
+            if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+            elif quantization_config.is_qat and pattern.is_qat:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+            elif not quantization_config.input_activation.is_dynamic:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)

-    def _annotate_all_dynamic_patterns(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[Callable[[Node], bool]] = None,
-    ) -> torch.fx.GraphModule:
-        # TODO: implement the support for None to be canceling out previous annotations
-        if quantization_config is None:
-            return model
-
-        for op in self.DYNAMIC_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
         return model

-    def _annotate_for_static_quantization_config(
+    def _annotate_for_quantization_config(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
-            self._annotate_all_static_patterns(
+            self._annotate_all_patterns(
                 model, config, _get_module_name_filter(module_name)
             )

         tp_list = list(self.module_type_config.keys())
         for module_type, config in self.module_type_config.items():
-            self._annotate_all_static_patterns(
+            self._annotate_all_patterns(
                 model, config, _get_module_type_filter(module_type)
             )

-        self._annotate_all_static_patterns(
-            model,
-            self.global_config,
-            _get_not_module_type_or_name_filter(tp_list, module_name_list),
-        )
-        return model
-
-    def _annotate_for_dynamic_quantization_config(
-        self, model: torch.fx.GraphModule
-    ) -> torch.fx.GraphModule:
-        module_name_list = list(self.module_name_config.keys())
-        for module_name, config in self.module_name_config.items():
-            self._annotate_all_dynamic_patterns(
-                model, config, _get_module_name_filter(module_name)
-            )
-
-        tp_list = list(self.module_type_config.keys())
-        for module_type, config in self.module_type_config.items():
-            self._annotate_all_dynamic_patterns(
-                model, config, _get_module_type_filter(module_type)
+        for op, config in self.operator_type_config.items():
+            self._annotate_all_patterns(
+                model,
+                config,
+                _get_not_module_type_or_name_filter(tp_list, module_name_list),
+                op,
             )
-
-        self._annotate_all_dynamic_patterns(
+        self._annotate_all_patterns(
             model,
             self.global_config,
             _get_not_module_type_or_name_filter(tp_list, module_name_list),
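
The core of this change is the QuantPattern table: every pattern now records whether it participates in dynamic quantization and/or QAT, and a single _annotate_all_patterns loop replaces the separate static and dynamic code paths. The snippet below is a minimal, self-contained sketch of that selection logic; the config class and string targets are stand-ins for the real QuantizationConfig and OpOverloadPacket objects, not ExecuTorch code.

    from dataclasses import dataclass
    from typing import Optional, Set


    @dataclass
    class QuantPattern:
        name: str
        is_dynamic: bool
        is_qat: bool
        op_overloads: Set[str]  # stand-in for torch OpOverloadPacket targets


    @dataclass
    class FakeConfig:  # stand-in for QuantizationConfig
        is_dynamic: bool  # stands in for input_activation.is_dynamic
        is_qat: bool


    PATTERNS = [
        QuantPattern("conv_bn", False, True, {"aten.conv2d"}),
        QuantPattern("linear", True, False, {"aten.linear"}),
        QuantPattern("conv", True, False, {"aten.conv2d"}),
        QuantPattern("cat", False, False, {"aten.cat"}),
    ]


    def patterns_to_annotate(config: FakeConfig, operator_target: Optional[str] = None):
        # Mirrors the branch structure of _annotate_all_patterns in the diff above.
        selected = []
        for pattern in PATTERNS:
            if operator_target and operator_target not in pattern.op_overloads:
                # A per-operator config only touches patterns tied to that target.
                continue
            if config.is_dynamic and pattern.is_dynamic:
                selected.append(pattern.name)
            elif config.is_qat and pattern.is_qat:
                selected.append(pattern.name)
            elif not config.is_dynamic:
                selected.append(pattern.name)
        return selected


    # Dynamic config scoped to linear: only the dynamic "linear" pattern fires.
    print(patterns_to_annotate(FakeConfig(is_dynamic=True, is_qat=False), "aten.linear"))
    # Static PTQ config with no operator scoping: every pattern takes the static
    # branch (QAT fusion patterns simply find no match in a non-QAT graph).
    print(patterns_to_annotate(FakeConfig(is_dynamic=False, is_qat=False)))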

backends/xnnpack/test/quantizer/test_pt2e_quantization.py (+12, -4)
@@ -172,10 +172,14 @@ def test_composable_quantizer_linear_conv(self) -> None:
         quantization_config_dynamic = get_symmetric_quantization_config(
             is_per_channel=False, is_dynamic=True
         )
-        dynamic_quantizer.set_global(quantization_config_dynamic)
+        dynamic_quantizer.set_operator_type(
+            torch.ops.aten.linear.default, quantization_config_dynamic
+        )
         static_quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        static_quantizer.set_global(quantization_config)
+        static_quantizer.set_operator_type(
+            torch.ops.aten.conv2d.default, quantization_config
+        )
         # Note that dynamic quantization must be applied first here.
         # this is because static quantizer also quantizes linear with static qspec
         # and if we apply static_quantizer first then dynamic_quantizer cannot be applied
@@ -271,10 +275,14 @@ def test_embedding_conv_linear_quantization(self) -> None:
         quantization_config_dynamic = get_symmetric_quantization_config(
             is_per_channel=True, is_dynamic=True
         )
-        dynamic_quantizer.set_global(quantization_config_dynamic)
+        dynamic_quantizer.set_operator_type(
+            torch.ops.aten.linear.default, quantization_config_dynamic
+        )
         static_quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
-        static_quantizer.set_global(quantization_config)
+        static_quantizer.set_operator_type(
+            torch.ops.aten.conv2d.default, quantization_config
+        )
         composed_quantizer = ComposableQuantizer(
             [embedding_quantizer, dynamic_quantizer, static_quantizer]
         )
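
For reference, this is roughly the composition the updated test exercises end to end: a dynamic config scoped to aten.linear and a static config scoped to aten.conv2d, composed and run through the PT2E flow. The quantizer calls mirror the diff; the import paths, toy model, and export/prepare steps are assumptions based on current PyTorch/ExecuTorch APIs rather than part of this commit.

    import torch

    # Import paths below are assumed, not taken from this diff.
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer

    # Dynamic quantization for linear only.
    dynamic_quantizer = XNNPACKQuantizer()
    dynamic_quantizer.set_operator_type(
        torch.ops.aten.linear.default,
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True),
    )

    # Static quantization for conv2d only.
    static_quantizer = XNNPACKQuantizer()
    static_quantizer.set_operator_type(
        torch.ops.aten.conv2d.default,
        get_symmetric_quantization_config(is_per_channel=True),
    )

    # Dynamic quantizer goes first, per the comment in the test: otherwise the
    # static quantizer would claim linear with a static qspec.
    composed = ComposableQuantizer([dynamic_quantizer, static_quantizer])


    class ConvThenLinear(torch.nn.Module):  # hypothetical toy model for illustration
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.linear(self.conv(x).mean(dim=(2, 3)))


    example_inputs = (torch.randn(1, 3, 16, 16),)
    m = torch.export.export_for_training(ConvThenLinear().eval(), example_inputs).module()
    m = prepare_pt2e(m, composed)
    m(*example_inputs)  # calibration pass for the static observers
    m = convert_pt2e(m)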

backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py (+1, -1)
@@ -665,7 +665,7 @@ def test_dynamic_linear_with_conv(self):
         quantization_config = get_symmetric_quantization_config(
             is_per_channel=False, is_dynamic=True
         )
-        quantizer.set_global(quantization_config)
+        quantizer.set_operator_type(torch.ops.aten.linear.default, quantization_config)
         m_eager = TestHelperModules.ConvLinearWPermute().eval()

         node_occurrence = {
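
This one-line change is the crux of the commit title: with dynamically quantized conv now enabled, a global dynamic config would annotate the conv in this model as well as the linear, which would shift the node counts the test asserts on. A hedged before/after sketch (the two calls mirror the diff; the surrounding setup is illustrative only):

    quantizer = XNNPACKQuantizer()
    dynamic_config = get_symmetric_quantization_config(
        is_per_channel=False, is_dynamic=True
    )

    # Before: a global dynamic config. Now that dynamic conv is supported, this
    # would also annotate the conv op, not just linear.
    # quantizer.set_global(dynamic_config)

    # After: scope the dynamic config to linear, leaving conv un-annotated so the
    # test's expected quantize/dequantize node occurrences still hold.
    quantizer.set_operator_type(torch.ops.aten.linear.default, dynamic_config)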
