
Commit 264ac90

[Backend Tester] Add FACTO operator test skeleton (#11953)
Add an initial skeleton for running FACTO operator tests. This is only set up for XNNPACK in this commit (other delegates come later in the stack), and it relies on a manual install of FACTO. I'm intentionally not running this in CI yet, both because it surfaces a number of crashes and because of the high volume of tests; this will be addressed further up the stack. Instructions for running the tests locally are included at the top of test_facto.py.
1 parent 9327708 commit 264ac90

6 files changed (+366, -10 lines)

backends/test/harness/tester.py

Lines changed: 1 addition & 0 deletions
@@ -361,6 +361,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
             ref,
             atol=atol,
             rtol=rtol,
+            equal_nan=True,
         ), (
             f"Output {i} does not match reference output.\n"
             f"\tGiven atol: {atol}, rtol: {rtol}.\n"

backends/test/operators/__init__.py

Whitespace-only changes.
backends/test/operators/facto_specs.py

Lines changed: 59 additions & 0 deletions
import facto.specdb.function as fn
import torch

from facto.inputgen.argument.type import ArgType
from facto.inputgen.specs.model import ConstraintProducer as cp, InPosArg, OutArg, Spec

"""
This file contains FACTO operator specs for ops not in the standard FACTO db. This mainly
includes ops not in the Core ATen op set and preserved by a backend, such as linear.
"""

LINEAR_DEFAULT_SPEC = Spec(
    op="linear.default",  # (Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
    inspec=[
        InPosArg(
            ArgType.Tensor,
            name="input",
            deps=[1, 2],
            constraints=[
                cp.Dtype.Eq(lambda deps: deps[0].dtype),
                cp.Rank.Ge(lambda deps: 2),
                cp.Size.In(
                    lambda deps, r, d: fn.broadcast_to(
                        (fn.safe_size(deps[0], 0), fn.safe_size(deps[1], 1)), r, d
                    )
                ),
            ],
        ),
        InPosArg(
            ArgType.Tensor,
            name="weight",
            constraints=[
                cp.Dtype.Ne(lambda deps: torch.bool),
                cp.Rank.Eq(lambda deps: 2),
            ],
        ),
        InPosArg(
            ArgType.Tensor,
            name="bias",
            deps=[1],
            constraints=[
                cp.Dtype.Eq(lambda deps: deps[0].dtype),
                cp.Rank.Eq(lambda deps: 2),
                cp.Size.Eq(
                    lambda deps, r, d: fn.safe_size(deps[0], 1) if d == 0 else None
                ),
            ],
        ),
    ],
    outspec=[
        OutArg(ArgType.Tensor),
    ],
)

_extra_specs = [
    LINEAR_DEFAULT_SPEC,
]

ExtraSpecDB: dict[str, Spec] = {s.op: s for s in _extra_specs}
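
As a usage note, here is a minimal sketch (not part of the commit) of driving LINEAR_DEFAULT_SPEC through FACTO's generator, mirroring the loop in test_facto.py. The import path for the spec module is an assumption for illustration, and FACTO is assumed to be installed manually.

# Minimal sketch: enumerate generated argument tuples for the linear spec.
import torch

from facto.inputgen.argtuple.gen import ArgumentTupleGenerator

from executorch.backends.test.operators.facto_specs import LINEAR_DEFAULT_SPEC  # assumed path

for posargs, inkwargs, _ in ArgumentTupleGenerator(LINEAR_DEFAULT_SPEC).gen():
    # Each tuple is one candidate test case: (input, weight, bias) for linear.default.
    shapes = [tuple(a.shape) if isinstance(a, torch.Tensor) else a for a in posargs]
    print(shapes, inkwargs)
    break  # inspect only the first generated case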

backends/test/operators/test_facto.py

Lines changed: 292 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import functools
import traceback
import unittest
from typing import Any, Callable, Sequence

import torch
from executorch.backends.test.harness.tester import Tester as TesterBase
from executorch.backends.xnnpack.test.tester.tester import Tester as XnnpackTester
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import ConstraintProducer as cp, Spec
from facto.inputgen.utils.random_manager import random_manager
from facto.specdb.db import SpecDictDB
from torch._ops import OpOverload

from .facto_specs import ExtraSpecDB

CombinedSpecDB = SpecDictDB | ExtraSpecDB

COMMON_TENSOR_CONSTRAINTS = [
    cp.Rank.Ge(lambda deps: 1),
    cp.Rank.Le(lambda deps: 4),
    cp.Size.Ge(lambda deps, r, d: 1),
    cp.Size.Le(lambda deps, r, d: 2**9),
]

COMMON_SCALAR_CONSTRAINS = [
    cp.Value.Ge(lambda deps, dtype: -1000),
    cp.Value.Le(lambda deps, dtype: 1000),
]

# Operator args are treated as runtime graph inputs if the argument name is
# in this list.
RUNTIME_INPUT_NAMES = {
    "self",
    "tensor",
    "other",
}


def _patch_spec(spec: Spec) -> Spec:
    spec = copy.deepcopy(spec)
    for inspec in spec.inspec:
        if inspec.type.is_tensor():
            inspec.constraints.extend(COMMON_TENSOR_CONSTRAINTS)
        elif inspec.type.is_scalar():
            inspec.constraints.extend(COMMON_SCALAR_CONSTRAINS)
    return spec


class OpModel(torch.nn.Module):
    """
    Wraps a single torch operator in an nn.Module.
    """

    def __init__(
        self,
        op: OpOverload,
        runtime_input_count: int,
        fixed_args: Sequence[Any],
        fixed_kwargs: dict[str, Any],
    ):
        super().__init__()
        self.op = op
        self.runtime_input_count = runtime_input_count
        self.fixed_kwargs = fixed_kwargs

        # Register parameters for fixed tensors. Some things will choke on
        # constant tensor weights, for example.
        new_args = []
        for i, arg in enumerate(fixed_args):
            if isinstance(arg, torch.Tensor):
                param = torch.nn.Parameter(arg, requires_grad=False)
                param_name = f"arg_{i}_param"
                setattr(self, param_name, param)
                self.register_parameter(param_name, param)
                new_args.append(param)
            else:
                new_args.append(arg)
        self.fixed_args = tuple(new_args)

    def forward(self, *args, **kwargs):
        return self.op(*(args + self.fixed_args), **(kwargs | self.fixed_kwargs))


# The convolution model has some minor wrapper logic around the actual convolution
# operator. Most of the backends are expecting this form.
# TODO (gjcomer) Investigate these discrepencies.
class ConvModel(OpModel):
    def forward(self, *args, **kwargs):
        weight, bias, stride, padding, dilation, transposed, output_padding, groups = (
            self.fixed_args
        )

        if not transposed:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv3d

            return op(args[0], weight, bias, stride, padding, dilation, groups)
        else:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv_transpose1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv_transpose2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv_transpose3d

            return op(
                args[0], weight, bias, stride, padding, output_padding, groups, dilation
            )


def get_module_for_op(op: OpOverload):
    if op == torch.ops.aten.convolution.default:
        return ConvModel
    else:
        return OpModel


class FactoTestsBase(unittest.TestCase):
    def __init__(self, tester_factory: Callable[[], TesterBase], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tester_factory = tester_factory

    @staticmethod
    def _generate_test(op_name: str) -> None:
        # Find the torch op with the given name.
        sections = op_name.split(".")
        torch_op = functools.reduce(getattr, sections, torch.ops.aten)

        test_name = "test_" + op_name.replace(".", "_")

        def test_body(self):
            self._test_op(torch_op)

        setattr(FactoTestsBase, test_name, test_body)

    @staticmethod
    def get_runtime_input_count(spec: Spec):
        # Determine which inputs are fixed at tracing time (weights, for example),
        # vs inputs to the runtime graph. We currently assume that the runtime graph
        # inputs start at the beginning of the arg list and are contiguous.
        #
        # Args are consider to be runtime inputs if they are positional and are named
        # one of RUNTIME_INPUT_NAMES. If none match, we assume only the first arg is a
        # runtime input.
        runtime_input_count = 0
        for inspec in spec.inspec:
            is_runtime_input = (
                inspec.type.is_tensor() and inspec.name.lower() in RUNTIME_INPUT_NAMES
            )
            if is_runtime_input:
                runtime_input_count += 1
            else:
                break

        return max(1, runtime_input_count)

    def setUp(self):
        torch.set_printoptions(threshold=3)

    def _test_op(self, op: OpOverload) -> None:  # noqa
        random_manager.seed(0)

        # Strip namespace
        op_name = op.name().split("::")[-1]

        # Default to .default overload
        if "." not in op_name:
            op_name += ".default"

        # Find and patch op spec
        if op_name not in CombinedSpecDB:
            raise ValueError(f"Operator {op_name} not found in SpecDictDB.")
        spec = _patch_spec(CombinedSpecDB[op_name])

        runtime_input_count = FactoTestsBase.get_runtime_input_count(spec)

        print(f"Op: {op_name}, {runtime_input_count} runtime inputs")

        # Run test cases
        success_count_delegated = 0
        success_count_undelegated = 0
        fail_count = 0

        i = 0
        for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen():
            i += 1

            try:
                if isinstance(posargs[0], torch.Tensor):
                    # Temporary for getting around XNN crashes (https://github.com/pytorch/executorch/issues/10960).
                    # TODO Re-enable when resolved.
                    if posargs[0].dtype in {torch.int8, torch.uint8}:
                        print("Skipping (u)int8 case.")
                        continue

                module_cls = get_module_for_op(op)
                model = module_cls(
                    op, runtime_input_count, posargs[runtime_input_count:], inkwargs
                )

                # Sanity check to make sure it runs in eager. This can present nicer error
                # messages sometimes compared to tracing.
                try:
                    model(*posargs[:runtime_input_count])
                except Exception as e:
                    print(f"Eager execution failed: {e}")
                    continue

                tester = self._tester_factory(
                    model, tuple(posargs[:runtime_input_count])
                )

                # Dynamo will also fail to handle some patterns that are valid in eager.
                try:
                    tester.export()
                except Exception:
                    print("Export failed.")
                    continue

                tester.to_edge_transform_and_lower()

                is_delegated = any(
                    n.target == torch._higher_order_ops.executorch_call_delegate
                    for n in tester.stages[tester.cur].graph_module.graph.nodes
                    if n.op == "call_function"
                )

                # Only run the runtime test if the op was delegated.
                if is_delegated:
                    (
                        tester.to_executorch()
                        .serialize()
                        .run_method_and_compare_outputs()
                    )

                if is_delegated:
                    success_count_delegated += 1
                else:
                    success_count_undelegated += 1
            except Exception:
                fail_count += 1
                print("Args:")
                for arg in posargs:
                    if isinstance(arg, torch.Tensor):
                        print(f"  {arg.dtype} {arg.shape}")
                    else:
                        print(f"  {arg}")

                traceback.print_exc()

        print(
            f"{success_count_delegated + success_count_undelegated} PASS, {fail_count} FAIL"
        )
        print(
            f"  {success_count_delegated} DELEGATED, {success_count_undelegated} UNDELEGATED"
        )


# Programatically generate tests for each operator.
for op_name in CombinedSpecDB.keys():
    FactoTestsBase._generate_test(op_name)


# TODO Figure out where to put these
class FactoTestsXNNPACK(FactoTestsBase):
    def __init__(self, *args, **kwargs):
        super().__init__(XnnpackTester, *args, **kwargs)


try:
    from executorch.backends.apple.coreml.test.tester import CoreMLTester

    class FactoTestsCoreML(FactoTestsBase):
        def __init__(self, *args, **kwargs):
            super().__init__(CoreMLTester, *args, **kwargs)

except:
    print("Skipping Core ML facto tests as Core ML AOT is not available.")

backends/xnnpack/test/tester/__init__.py

Lines changed: 9 additions & 10 deletions
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# TODO: Be more delibrate on module structure
 from executorch.backends.xnnpack.test.tester.tester import (
     Export,
     Partition,
@@ -18,13 +17,13 @@
 )
 
 __all__ = [
-    Export,
-    ToEdge,
-    Partition,
-    Quantize,
-    RunPasses,
-    ToEdgeTransformAndLower,
-    Tester,
-    Serialize,
-    ToExecutorch,
+    "Export",
+    "Partition",
+    "Quantize",
+    "RunPasses",
+    "Serialize",
+    "Tester",
+    "ToEdge",
+    "ToEdgeTransformAndLower",
+    "ToExecutorch",
 ]
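
For context, a small sketch (an assumption about the motivation, not text from this commit): wildcard imports resolve __all__ entries by name, so the entries need to be strings rather than the imported objects themselves.

# Minimal sketch of the consumer-side behavior the string entries support.
from executorch.backends.xnnpack.test.tester import *  # noqa: F401,F403

# The star import roughly expands to:
#   for name in tester.__all__:              # entries are looked up by name
#       globals()[name] = getattr(tester, name)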
