Skip to content

Commit 1af1d11

Browse files
authored
[Backend Tester] Add quantized test flows for XNNPACK and Core ML (#12733)
Add quantized test flows covering static int8 quantization for the XNNPACK and Core ML backends. I also ended up doing some light refactoring on the test signature to pass the TestFlow class into the individual tests. This is done to allow for passing quantization parameters into the inner test.
1 parent 8f062d3 commit 1af1d11

31 files changed

+485
-366
lines changed

backends/apple/coreml/test/tester.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,73 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Any, List, Optional, Tuple
7+
import functools
8+
from typing import Any, List, Optional, Sequence, Tuple
89

10+
import coremltools as ct
911
import executorch
1012
import executorch.backends.test.harness.stages as BaseStages
11-
1213
import torch
14+
15+
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1316
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
17+
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
1418
from executorch.backends.test.harness import Tester as TesterBase
1519
from executorch.backends.test.harness.stages import StageType
1620
from executorch.exir import EdgeCompileConfig
1721
from executorch.exir.backend.partitioner import Partitioner
1822

1923

24+
def _create_default_partitioner(
25+
minimum_deployment_target: Any = ct.target.iOS15,
26+
) -> CoreMLPartitioner:
27+
return CoreMLPartitioner(
28+
compile_specs=CoreMLBackend.generate_compile_specs(
29+
minimum_deployment_target=minimum_deployment_target
30+
)
31+
)
32+
33+
34+
def _get_static_int8_linear_qconfig():
35+
return ct.optimize.torch.quantization.LinearQuantizerConfig(
36+
global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
37+
quantization_scheme="symmetric",
38+
activation_dtype=torch.quint8,
39+
weight_dtype=torch.qint8,
40+
weight_per_channel=True,
41+
)
42+
)
43+
44+
45+
class Quantize(BaseStages.Quantize):
46+
def __init__(
47+
self,
48+
quantizer: Optional[CoreMLQuantizer] = None,
49+
quantization_config: Optional[Any] = None,
50+
calibrate: bool = True,
51+
calibration_samples: Optional[Sequence[Any]] = None,
52+
is_qat: Optional[bool] = False,
53+
):
54+
super().__init__(
55+
quantizer=quantizer
56+
or CoreMLQuantizer(
57+
quantization_config or _get_static_int8_linear_qconfig()
58+
),
59+
calibrate=calibrate,
60+
calibration_samples=calibration_samples,
61+
is_qat=is_qat,
62+
)
63+
64+
2065
class Partition(BaseStages.Partition):
21-
def __init__(self, partitioner: Optional[Partitioner] = None):
66+
def __init__(
67+
self,
68+
partitioner: Optional[Partitioner] = None,
69+
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
70+
):
2271
super().__init__(
23-
partitioner=partitioner or CoreMLPartitioner,
72+
partitioner=partitioner
73+
or _create_default_partitioner(minimum_deployment_target),
2474
)
2575

2676

@@ -29,9 +79,12 @@ def __init__(
2979
self,
3080
partitioners: Optional[List[Partitioner]] = None,
3181
edge_compile_config: Optional[EdgeCompileConfig] = None,
82+
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
3283
):
3384
super().__init__(
34-
default_partitioner_cls=CoreMLPartitioner,
85+
default_partitioner_cls=lambda: _create_default_partitioner(
86+
minimum_deployment_target
87+
),
3588
partitioners=partitioners,
3689
edge_compile_config=edge_compile_config,
3790
)
@@ -43,13 +96,20 @@ def __init__(
4396
module: torch.nn.Module,
4497
example_inputs: Tuple[torch.Tensor],
4598
dynamic_shapes: Optional[Tuple[Any]] = None,
99+
minimum_deployment_target: Optional[Any] = ct.target.iOS15,
46100
):
47101
# Specialize for Core ML
48102
stage_classes = (
49103
executorch.backends.test.harness.Tester.default_stage_classes()
50104
| {
51-
StageType.PARTITION: Partition,
52-
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
105+
StageType.QUANTIZE: Quantize,
106+
StageType.PARTITION: functools.partial(
107+
Partition, minimum_deployment_target=minimum_deployment_target
108+
),
109+
StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(
110+
ToEdgeTransformAndLower,
111+
minimum_deployment_target=minimum_deployment_target,
112+
),
53113
}
54114
)
55115

backends/test/harness/stages/quantize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@ def __init__(
2525
calibrate: bool = True,
2626
calibration_samples: Optional[Sequence[Any]] = None,
2727
is_qat: Optional[bool] = False,
28+
set_global: bool = True,
2829
):
2930
self.quantizer = quantizer
3031
self.quantization_config = quantization_config
3132
self.calibrate = calibrate
3233
self.calibration_samples = calibration_samples
3334

34-
self.quantizer.set_global(self.quantization_config)
35+
if self.quantization_config is not None and set_global:
36+
self.quantizer.set_global(self.quantization_config)
3537

3638
self.converted_graph = None
3739
self.is_qat = is_qat

backends/test/harness/tester.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections import Counter, OrderedDict
3-
from typing import Any, Dict, List, Optional, Tuple, Type
3+
from typing import Any, Callable, Dict, List, Optional, Tuple
44

55
import torch
66

@@ -33,7 +33,7 @@ def __init__(
3333
self,
3434
module: torch.nn.Module,
3535
example_inputs: Tuple[torch.Tensor],
36-
stage_classes: Dict[StageType, Type],
36+
stage_classes: Dict[StageType, Callable],
3737
dynamic_shapes: Optional[Tuple[Any]] = None,
3838
):
3939
module.eval()
@@ -81,7 +81,7 @@ def __init__(
8181
self.stage_output = None
8282

8383
@staticmethod
84-
def default_stage_classes() -> Dict[StageType, Type]:
84+
def default_stage_classes() -> Dict[StageType, Callable]:
8585
"""
8686
Returns a map of StageType to default Stage implementation.
8787
"""

backends/test/suite/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _make_wrapped_test(
129129
def wrapped_test(self):
130130
with TestContext(test_name, flow.name, params):
131131
test_kwargs = params or {}
132-
test_kwargs["tester_factory"] = flow.tester_factory
132+
test_kwargs["flow"] = flow
133133

134134
test_func(self, **test_kwargs)
135135

@@ -175,7 +175,7 @@ def load_tests(loader, suite, pattern):
175175

176176

177177
class OperatorTest(unittest.TestCase):
178-
def _test_op(self, model, inputs, tester_factory):
178+
def _test_op(self, model, inputs, flow: TestFlow):
179179
context = get_active_test_context()
180180

181181
# This should be set in the wrapped test. See _make_wrapped_test above.
@@ -184,9 +184,8 @@ def _test_op(self, model, inputs, tester_factory):
184184
run_summary = run_test(
185185
model,
186186
inputs,
187-
tester_factory,
187+
flow,
188188
context.test_name,
189-
context.flow_name,
190189
context.params,
191190
)
192191

backends/test/suite/flow.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
22

3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, field
44
from typing import Callable
55

66
from executorch.backends.test.harness import Tester
7+
from executorch.backends.test.harness.stages import Quantize
78

89
logger = logging.getLogger(__name__)
910
logger.setLevel(logging.INFO)
@@ -22,41 +23,43 @@ class TestFlow:
2223
backend: str
2324
""" The name of the target backend. """
2425

25-
tester_factory: Callable[[], Tester]
26+
tester_factory: Callable[..., Tester]
2627
""" A factory function that returns a Tester instance for this lowering flow. """
2728

29+
quantize: bool = field(default=False)
30+
""" Whether the tester should run the quantize stage on the model. """
2831

29-
def create_xnnpack_flow() -> TestFlow | None:
30-
try:
31-
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester
32+
quantize_stage_factory: Callable[..., Quantize] | None = None
33+
""" A factory function which instantiates a Quantize stage. Can be None to use the tester's default. """
3234

33-
return TestFlow(
34-
name="xnnpack",
35-
backend="xnnpack",
36-
tester_factory=XnnpackTester,
37-
)
38-
except Exception:
39-
logger.info("Skipping XNNPACK flow registration due to import failure.")
40-
return None
4135

36+
def all_flows() -> dict[str, TestFlow]:
37+
flows = []
4238

43-
def create_coreml_flow() -> TestFlow | None:
4439
try:
45-
from executorch.backends.apple.coreml.test.tester import CoreMLTester
40+
from executorch.backends.test.suite.flows.xnnpack import (
41+
XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW,
42+
XNNPACK_TEST_FLOW,
43+
)
4644

47-
return TestFlow(
48-
name="coreml",
49-
backend="coreml",
50-
tester_factory=CoreMLTester,
45+
flows += [
46+
XNNPACK_TEST_FLOW,
47+
XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW,
48+
]
49+
except Exception as e:
50+
logger.info(f"Skipping XNNPACK flow registration: {e}")
51+
52+
try:
53+
from executorch.backends.test.suite.flows.coreml import (
54+
COREML_STATIC_INT8_TEST_FLOW,
55+
COREML_TEST_FLOW,
5156
)
52-
except Exception:
53-
logger.info("Skipping Core ML flow registration due to import failure.")
54-
return None
5557

58+
flows += [
59+
COREML_TEST_FLOW,
60+
COREML_STATIC_INT8_TEST_FLOW,
61+
]
62+
except Exception as e:
63+
logger.info(f"Skipping Core ML flow registration: {e}")
5664

57-
def all_flows() -> dict[str, TestFlow]:
58-
flows = [
59-
create_xnnpack_flow(),
60-
create_coreml_flow(),
61-
]
6265
return {f.name: f for f in flows if f is not None}

backends/test/suite/flows/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe

backends/test/suite/flows/coreml.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import functools
2+
from typing import Any
3+
4+
import coremltools
5+
6+
from executorch.backends.apple.coreml.test.tester import CoreMLTester
7+
from executorch.backends.test.suite.flow import TestFlow
8+
9+
10+
def _create_coreml_flow(
11+
name: str,
12+
quantize: bool = False,
13+
minimum_deployment_target: Any = coremltools.target.iOS15,
14+
) -> TestFlow:
15+
return TestFlow(
16+
name,
17+
backend="coreml",
18+
tester_factory=functools.partial(
19+
CoreMLTester, minimum_deployment_target=minimum_deployment_target
20+
),
21+
quantize=quantize,
22+
)
23+
24+
25+
COREML_TEST_FLOW = _create_coreml_flow("coreml")
26+
COREML_STATIC_INT8_TEST_FLOW = _create_coreml_flow(
27+
"coreml_static_int8",
28+
quantize=True,
29+
minimum_deployment_target=coremltools.target.iOS17,
30+
)

backends/test/suite/flows/xnnpack.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import logging
2+
from typing import Callable
3+
4+
from executorch.backends.test.harness.stages import Quantize
5+
from executorch.backends.test.suite.flow import TestFlow
6+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
7+
get_symmetric_quantization_config,
8+
)
9+
from executorch.backends.xnnpack.test.tester import (
10+
Quantize as XnnpackQuantize,
11+
Tester as XnnpackTester,
12+
)
13+
14+
logger = logging.getLogger(__name__)
15+
logger.setLevel(logging.INFO)
16+
17+
18+
def _create_xnnpack_flow_base(
19+
name: str, quantize_stage_factory: Callable[..., Quantize] | None = None
20+
) -> TestFlow:
21+
return TestFlow(
22+
name,
23+
backend="xnnpack",
24+
tester_factory=XnnpackTester,
25+
quantize=quantize_stage_factory is not None,
26+
quantize_stage_factory=quantize_stage_factory,
27+
)
28+
29+
30+
def _create_xnnpack_flow() -> TestFlow:
31+
return _create_xnnpack_flow_base("xnnpack")
32+
33+
34+
def _create_xnnpack_static_int8_per_channel_flow() -> TestFlow:
35+
def create_quantize_stage() -> Quantize:
36+
qparams = get_symmetric_quantization_config(is_per_channel=True)
37+
return XnnpackQuantize(
38+
quantization_config=qparams,
39+
)
40+
41+
return _create_xnnpack_flow_base(
42+
"xnnpack_static_int8_per_channel", create_quantize_stage
43+
)
44+
45+
46+
XNNPACK_TEST_FLOW = _create_xnnpack_flow()
47+
XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW = (
48+
_create_xnnpack_static_int8_per_channel_flow()
49+
)

backends/test/suite/models/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing import Any, Callable
1313

1414
import torch
15-
from executorch.backends.test.harness import Tester
1615
from executorch.backends.test.suite import get_test_flows
1716
from executorch.backends.test.suite.context import get_active_test_context, TestContext
1817
from executorch.backends.test.suite.flow import TestFlow
@@ -49,7 +48,7 @@ def wrapped_test(self):
4948
"use_dynamic_shapes": use_dynamic_shapes,
5049
}
5150
with TestContext(test_name, flow.name, params):
52-
test_func(self, dtype, use_dynamic_shapes, flow.tester_factory)
51+
test_func(self, flow, dtype, use_dynamic_shapes)
5352

5453
dtype_name = str(dtype)[6:] # strip "torch."
5554
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
@@ -104,9 +103,9 @@ def inner_decorator(func: Callable) -> Callable:
104103
def run_model_test(
105104
model: torch.nn.Module,
106105
inputs: tuple[Any],
106+
flow: TestFlow,
107107
dtype: torch.dtype,
108108
dynamic_shapes: Any | None,
109-
tester_factory: Callable[[], Tester],
110109
):
111110
model = model.to(dtype)
112111
context = get_active_test_context()
@@ -117,9 +116,8 @@ def run_model_test(
117116
run_summary = run_test(
118117
model,
119118
inputs,
120-
tester_factory,
119+
flow,
121120
context.test_name,
122-
context.flow_name,
123121
context.params,
124122
dynamic_shapes=dynamic_shapes,
125123
)

0 commit comments

Comments
 (0)