 
 import copy
 import functools
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Set, TYPE_CHECKING
 
 import torch
 import torch._dynamo as torchdynamo
@@ -235,37 +236,52 @@ def not_module_type_or_name_filter(n: Node) -> bool:
     return not_module_type_or_name_filter
 
 
-class XNNPACKQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-    STATIC_QAT_ONLY_OPS = [
-        "conv_bn_relu",
-        "conv_bn",
-        "conv_transpose_bn_relu",
-        "conv_transpose_bn",
-    ]
+@dataclass
+class QuantPattern:
+    name: str
+    is_dynamic: bool
+    is_qat: bool
+    op_overloads: Set[torch._ops.OpOverloadPacket]
+
+
+CONV_TARGETS = {
+    torch.ops.aten.conv2d.default,
+    torch.ops.aten.conv1d.default,
+    torch.ops.aten.convolution.default,
+}
+
+LINEAR_TARGETS = {
+    torch.ops.aten.linear.default,
+}
+
+ADAPTIVE_AVG_POOL2D_TARGETS = {torch.ops.aten.adaptive_avg_pool2d.default}
+
+ADD_TARGETS = {torch.ops.aten.add.Tensor}
+
+MUL_TARGETS = {torch.ops.aten.mul.Tensor}
+
+CAT_TARGETS = {torch.ops.aten.cat.default}
 
-    # static quantization ops (both PTQ and QAT)
-    # Preserve the order that fusions come before singular ops
-    STATIC_OPS = [
-        "linear_relu",
-        "linear",
-        "conv",
-        "conv_transpose",
-        "conv_relu",
-        "conv_transpose_relu",
-        "adaptive_avg_pool2d",
-        # TODO: move this to BoltNNQuantizer?
-        "gru_io_only",
-        "add_relu",
-        "add",
-        "mul_relu",
-        "mul",
-        "cat",
-    ]
 
-    DYNAMIC_OPS = [
-        "linear",
-        "conv",
+class XNNPACKQuantizer(Quantizer):
+    supported_config_and_operators = _get_supported_config_and_operators()
+    SUPPORTED_PATTERNS = [
+        QuantPattern("conv_bn_relu", False, True, CONV_TARGETS),
+        QuantPattern("conv_bn", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS),
+        QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
+        QuantPattern("linear", True, False, LINEAR_TARGETS),
+        QuantPattern("conv", True, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", False, False, CONV_TARGETS),
+        QuantPattern("conv_relu", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
+        QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
+        QuantPattern("add_relu", False, False, ADD_TARGETS),
+        QuantPattern("add", False, False, ADD_TARGETS),
+        QuantPattern("mul_relu", False, False, MUL_TARGETS),
+        QuantPattern("mul", False, False, MUL_TARGETS),
+        QuantPattern("cat", False, False, CAT_TARGETS),
     ]
 
     def __init__(self) -> None:
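Note: the new layout pairs each annotator name with is_dynamic/is_qat flags and a set of aten overloads, so a caller can ask "which patterns cover this operator?" directly. Below is a minimal standalone sketch of that lookup, not part of the commit; it re-declares QuantPattern and two target sets locally so it runs on its own, and the pattern list is a shortened illustration.

    from dataclasses import dataclass
    from typing import List, Set

    import torch


    @dataclass
    class QuantPattern:
        name: str
        is_dynamic: bool
        is_qat: bool
        op_overloads: Set[torch._ops.OpOverloadPacket]


    LINEAR_TARGETS = {torch.ops.aten.linear.default}
    CONV_TARGETS = {torch.ops.aten.conv2d.default, torch.ops.aten.conv1d.default}

    patterns: List[QuantPattern] = [
        QuantPattern("linear", True, False, LINEAR_TARGETS),
        QuantPattern("conv", True, False, CONV_TARGETS),
        QuantPattern("conv_bn", False, True, CONV_TARGETS),
    ]

    # Mirror the operator_target check used later in _annotate_all_patterns:
    # keep only patterns whose overload set contains the requested target.
    target = torch.ops.aten.linear.default
    print([p.name for p in patterns if target in p.op_overloads])  # ['linear']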
@@ -347,83 +363,58 @@ def transform_for_annotation(
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         """just handling global spec for now"""
-        # hacked for handling dynamic linear quant. will fix later.
-        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
-            model = self._annotate_for_dynamic_quantization_config(model)
-        else:
-            model = self._annotate_for_static_quantization_config(model)
+        model = self._annotate_for_quantization_config(model)
         propagate_annotation(model)
         return model
 
-    def _annotate_all_static_patterns(
+    def _annotate_all_patterns(
         self,
         model: torch.fx.GraphModule,
         quantization_config: Optional[QuantizationConfig],
         filter_fn: Optional[Callable[[Node], bool]] = None,
-    ) -> torch.fx.GraphModule:
+        operator_target: Optional[torch._ops.OpOverloadPacket] = None,
+    ):
         # TODO: implement the support for None to be canceling out previous annotations
         if quantization_config is None:
             return model
 
-        if quantization_config.is_qat:
-            for op in self.STATIC_QAT_ONLY_OPS:
-                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
-        for op in self.STATIC_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
-        return model
+        for pattern in self.SUPPORTED_PATTERNS:
+            if operator_target and operator_target not in pattern.op_overloads:
+                # if operator_target is specified, skip patterns that aren't
+                # associated with that target
+                continue
+            if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+            elif quantization_config.is_qat and pattern.is_qat:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+            elif not quantization_config.input_activation.is_dynamic:
+                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
 
-    def _annotate_all_dynamic_patterns(
-        self,
-        model: torch.fx.GraphModule,
-        quantization_config: Optional[QuantizationConfig],
-        filter_fn: Optional[Callable[[Node], bool]] = None,
-    ) -> torch.fx.GraphModule:
-        # TODO: implement the support for None to be canceling out previous annotations
-        if quantization_config is None:
-            return model
-
-        for op in self.DYNAMIC_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
         return model
 
-    def _annotate_for_static_quantization_config(
+    def _annotate_for_quantization_config(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
-            self._annotate_all_static_patterns(
+            self._annotate_all_patterns(
                 model, config, _get_module_name_filter(module_name)
             )
 
         tp_list = list(self.module_type_config.keys())
         for module_type, config in self.module_type_config.items():
-            self._annotate_all_static_patterns(
+            self._annotate_all_patterns(
                 model, config, _get_module_type_filter(module_type)
             )
 
-        self._annotate_all_static_patterns(
-            model,
-            self.global_config,
-            _get_not_module_type_or_name_filter(tp_list, module_name_list),
-        )
-        return model
-
-    def _annotate_for_dynamic_quantization_config(
-        self, model: torch.fx.GraphModule
-    ) -> torch.fx.GraphModule:
-        module_name_list = list(self.module_name_config.keys())
-        for module_name, config in self.module_name_config.items():
-            self._annotate_all_dynamic_patterns(
-                model, config, _get_module_name_filter(module_name)
-            )
-
-        tp_list = list(self.module_type_config.keys())
-        for module_type, config in self.module_type_config.items():
-            self._annotate_all_dynamic_patterns(
-                model, config, _get_module_type_filter(module_type)
+        for op, config in self.operator_type_config.items():
+            self._annotate_all_patterns(
+                model,
+                config,
+                _get_not_module_type_or_name_filter(tp_list, module_name_list),
+                op,
             )
-
-        self._annotate_all_dynamic_patterns(
+        self._annotate_all_patterns(
             model,
             self.global_config,
             _get_not_module_type_or_name_filter(tp_list, module_name_list),
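With the static and dynamic paths folded into a single method, the branch on input_activation.is_dynamic and is_qat decides which subset of SUPPORTED_PATTERNS each config annotates. The following toy sketch restates that selection logic; it is standalone and uses simplified stand-ins (FakeConfig, PATTERNS, selected) for QuantizationConfig and the annotator table, so these names are illustrative only.

    from dataclasses import dataclass
    from typing import List, Tuple


    @dataclass
    class FakeConfig:
        # stand-ins for quantization_config.input_activation.is_dynamic / .is_qat
        is_dynamic: bool
        is_qat: bool


    # (name, pattern.is_dynamic, pattern.is_qat) for a few representative patterns
    PATTERNS: List[Tuple[str, bool, bool]] = [
        ("conv_bn", False, True),
        ("linear", True, False),
        ("conv", True, False),
        ("add", False, False),
    ]


    def selected(cfg: FakeConfig) -> List[str]:
        names = []
        for name, p_dynamic, p_qat in PATTERNS:
            if cfg.is_dynamic and p_dynamic:
                names.append(name)  # dynamic config: only dynamic-capable patterns
            elif cfg.is_qat and p_qat:
                names.append(name)  # QAT config: fused QAT patterns
            elif not cfg.is_dynamic:
                names.append(name)  # any other static config: remaining patterns
        return names


    print(selected(FakeConfig(is_dynamic=True, is_qat=False)))   # ['linear', 'conv']
    print(selected(FakeConfig(is_dynamic=False, is_qat=True)))   # ['conv_bn', 'linear', 'conv', 'add']
    print(selected(FakeConfig(is_dynamic=False, is_qat=False)))  # ['conv_bn', 'linear', 'conv', 'add']

The three branches mirror the loop in the new _annotate_all_patterns above: a dynamic config only reaches dynamic-capable patterns, while static configs (QAT or PTQ) fall through to every non-dynamic-selected pattern.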