Commit 38525d6

Add decompositions to import flow (#27)
Adds default decompositions from SHARK to the Turbine import flow, with tests for importing `aten.chunk` and `nn.BatchNorm2d`, both of which previously failed.
1 parent a2a58bd commit 38525d6
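
The change centers on one pattern, shown in full in python/shark_turbine/dynamo/passes.py below: before import, the FX graph is retraced through a table of ATen decompositions so that ops the importer cannot yet handle (e.g. aten.split.Tensor, which torch.chunk lowers to) are rewritten into primitives it can. A minimal standalone sketch of that pattern, using only the stock PyTorch APIs the commit itself uses (the helper name `decompose` is illustrative):

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize


def decompose(gm: torch.fx.GraphModule, example_inputs, ops):
    # Retrace gm, rewriting every op in `ops` via its registered decomposition.
    table = get_decompositions(ops)
    return make_fx(functionalize(gm), decomposition_table=table)(*example_inputs)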

File tree

9 files changed: +210 -73 lines changed

python/shark_turbine/dynamo/backends/cpu.py
Lines changed: 4 additions & 0 deletions

@@ -37,6 +37,7 @@

 import torch
 from torch._dynamo.backends.common import aot_autograd
+from ..passes import turbine_cpu_pass_pipeline

 DEFAULT_COMPILER_FLAGS = (
     # Enable asynchronous calling convention.
@@ -65,6 +66,9 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
     inv.enable_console_diagnostics()
     inv.import_module(module.operation)

+    # Apply decompositions.
+    gm = turbine_cpu_pass_pipeline(gm, example_inputs)
+
     # Import phase.
     importer.import_graph_module(gm)
     print(module, file=sys.stderr)
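
Note the ordering the hunk establishes: the graph is decomposed first, then handed to the importer. A minimal import-only backend following the same shape — a sketch, assuming the public pieces the diff itself touches; the helper name and the eager fallback are illustrative, not part of the module:

import torch
from torch._dynamo.backends.common import aot_autograd
from shark_turbine.dynamo.importer import FxImporter
from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline


def _import_only_backend(gm: torch.fx.GraphModule, example_inputs):
    # Decompose first, then import, mirroring _base_backend above.
    gm = turbine_cpu_pass_pipeline(gm, example_inputs)
    importer = FxImporter()
    importer.import_graph_module(gm)
    importer.module.operation.verify()
    return gm  # run the decomposed graph eagerly


# Usage: torch.compile(fn, backend=aot_autograd(fw_compiler=_import_only_backend))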

python/shark_turbine/dynamo/importer.py
Lines changed: 16 additions & 29 deletions

@@ -136,10 +136,10 @@ class FxImporter:
     ]

     def __init__(
-            self,
-            module: Optional[Module] = None,
-            context: Optional[Context] = None,
-            config_check: bool = True,
+        self,
+        module: Optional[Module] = None,
+        context: Optional[Context] = None,
+        config_check: bool = True,
     ):
         if module is not None:
             assert context is None, "If configuring with a Module, context must be None"
@@ -214,7 +214,9 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
                 # always be "boxed" as a tuple, which we emit as multi-results.
                 for result_node in node.args[0]:
                     if result_node is None:
-                        result_types.append(MlirType.parse("!torch.none", context=self._c))
+                        result_types.append(
+                            MlirType.parse("!torch.none", context=self._c)
+                        )
                     else:
                         result_types.append(self._cc.node_val_to_type(result_node))
         return (
@@ -390,7 +392,7 @@ def import_nodes(self, nodes: Sequence[torch_fx.Node]):
                 func_dialect.ReturnOp(operands, loc=loc)

     def _import_torch_op_overload(
-            self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
+        self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
     ):
         schema = target._schema
         assert isinstance(schema, FunctionSchema)
@@ -404,7 +406,7 @@ def _import_torch_op_overload(

         # Intervening to use Scalar ops due to incorrect ops from AOT-autograd with scalar arguments.
         if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER and (
-                isinstance(node.args[1], float) or isinstance(node.args[1], int)
+            isinstance(node.args[1], float) or isinstance(node.args[1], int)
         ):
             mlir_op_name = TENSOR_SCALAR_OP_CONVERTER[mlir_op_name]

@@ -487,9 +489,7 @@ def _import_list_argument(self, loc: Location, arg):
         result_type = SCALAR_TYPE_TO_TORCH_LIST_TYPE.get(arg_type, None)

         if result_type is not None:
-            result_type = MlirType.parse(
-                result_type, context=self._c
-            )
+            result_type = MlirType.parse(result_type, context=self._c)

         for operand in arg:
             operand_type = type(operand)
@@ -498,21 +498,6 @@ def _import_list_argument(self, loc: Location, arg):
                     f"Lists with multiple types are not supported, got: {arg_type}, {operand_type}"
                 )

-            if isinstance(operand, torch.fx.Node):
-                if operand in self._multi_result_nodes:
-                    raise RuntimeError(f"Attempt to de-reference a multi-result node")
-                val = self._v[(operand, 0)]
-                if result_type is None:
-                    list_type: str = str(val.type)
-                    begin_index = 7 if list_type.startswith("!torch.") else None
-                    end_index = list_type.find("<")
-                    end_index = end_index if end_index != -1 else None
-                    list_type = list_type[begin_index:end_index]
-                    result_type = MlirType.parse(f"!torch.list<{list_type}>")
-            else:
-                val = self._import_default_value(
-                    loc, operand, SCALAR_TYPE_TO_TORCH_TYPE[type(operand)]
-                )
             if isinstance(operand, torch.fx.Node):
                 if operand in self._multi_result_nodes:
                     raise RuntimeError(f"Attempt to de-reference a multi-result node")
@@ -522,7 +507,9 @@ def _import_list_argument(self, loc: Location, arg):
                     pattern = r"^!torch\.(.*?)(?:<.*>)?$"
                     val_type = str(val.type)
                     match = re.match(pattern, val_type)
-                    assert match is not None, f"Unexpected MlirType in list: \'{val_type}\'"
+                    assert (
+                        match is not None
+                    ), f"Unexpected MlirType in list: '{val_type}'"
                     list_type = match.group(1)
                     result_type = MlirType.parse(f"!torch.list<{list_type}>")
             else:
@@ -595,7 +582,7 @@ def lookup(self, t: type) -> Any:


 def _make_constant_op(
-        op_name: str, value_attr: MlirAttribute, result_type: Optional[MlirType] = None
+    op_name: str, value_attr: MlirAttribute, result_type: Optional[MlirType] = None
 ) -> Operation:
     return Operation.create(
         op_name,
@@ -664,14 +651,14 @@ def _make_constant_op(
     int: "!torch.list<int>",
     float: "!torch.list<float>",
     str: "!torch.list<str>",
-    bool: "!torch.list<bool>"
+    bool: "!torch.list<bool>",
 }

 SCALAR_TYPE_TO_TORCH_TYPE = {
     int: "!torch.int",
     float: "!torch.float",
     str: "!torch.str",
-    bool: "!torch.bool"
+    bool: "!torch.bool",
 }

 # AOT-autograd sometimes falsely emit tensor version op with scalar arguments.
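
The retained branch relies on the regex shown in the context lines above rather than the removed index arithmetic; a quick sketch of what that pattern extracts (inputs here are illustrative):

import re

# Pattern from _import_list_argument: capture the base !torch type name,
# dropping any <...> parameter list.
pattern = r"^!torch\.(.*?)(?:<.*>)?$"

for val_type in ["!torch.int", "!torch.vtensor<[4],f32>"]:
    match = re.match(pattern, val_type)
    assert match is not None, f"Unexpected MlirType in list: '{val_type}'"
    print(f"!torch.list<{match.group(1)}>")
# prints: !torch.list<int>
# prints: !torch.list<vtensor>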

python/shark_turbine/dynamo/passes.py
Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+import torch
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._decomp import get_decompositions
+from torch.func import functionalize
+from typing import List
+
+# default decompositions pulled from SHARK
+DEFAULT_DECOMPOSITIONS = [
+    torch.ops.aten.embedding_dense_backward,
+    torch.ops.aten.native_layer_norm_backward,
+    torch.ops.aten.slice_backward,
+    torch.ops.aten.select_backward,
+    torch.ops.aten.norm.ScalarOpt_dim,
+    torch.ops.aten.native_group_norm,
+    torch.ops.aten.upsample_bilinear2d.vec,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.split_with_sizes,
+    torch.ops.aten.native_layer_norm,
+    torch.ops.aten.masked_fill.Tensor,
+    torch.ops.aten.masked_fill.Scalar,
+]
+
+# decompositions that aid us in handling nn.BatchNorm2d
+BATCHNORM_DECOMPOSITIONS = [
+    torch.ops.aten._native_batch_norm_legit_functional,
+    torch.ops.aten.squeeze.dims,
+]
+
+
+def apply_decompositions(
+    gm: torch.fx.GraphModule,
+    example_inputs,
+    decompose_ops: List[torch._ops.OpOverload] = None,
+):
+    if decompose_ops is None:
+        return gm
+
+    decompositions = get_decompositions(decompose_ops)
+    gm = make_fx(
+        functionalize(gm),
+        decomposition_table=decompositions,
+    )(*example_inputs)
+
+    return gm
+
+
+def turbine_cpu_pass_pipeline(gm: torch.fx.GraphModule, example_inputs):
+    decompose_ops = DEFAULT_DECOMPOSITIONS + BATCHNORM_DECOMPOSITIONS
+    return apply_decompositions(gm, example_inputs, decompose_ops)
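
For reference, the new pass can be exercised outside the backend; a minimal sketch, assuming a graph obtained via torch.fx.symbolic_trace stands in for the Dynamo-produced one:

import torch
from shark_turbine.dynamo.passes import apply_decompositions


def foo(x):
    return torch.chunk(x, 2, dim=-1)


gm = torch.fx.symbolic_trace(foo)
example_inputs = [torch.randn(4, 4)]
# torch.chunk lowers to aten.split.Tensor; retracing with its
# decomposition leaves only primitive slicing ops in the graph.
gm = apply_decompositions(
    gm, example_inputs, decompose_ops=[torch.ops.aten.split.Tensor]
)
gm.print_readable()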

python/test/dynamo/importer_basic_test.py
Lines changed: 37 additions & 9 deletions

@@ -6,21 +6,31 @@

 import logging
 import unittest
+from typing import List

 from shark_turbine.dynamo.importer import FxImporter
 import torch
 import torch._dynamo as dynamo
 from torch._dynamo.backends.common import aot_autograd
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._decomp import get_decompositions
+from torch.func import functionalize
 from torch.fx import (
     GraphModule,
 )


 class ImportTests(unittest.TestCase):
-    def create_backend(self):
+    def create_backend(self, decompose_ops: List[torch._ops.OpOverloadPacket] = None):
         imp = FxImporter()

         def import_compiler(gm: GraphModule, example_inputs):
+            if decompose_ops is not None:
+                gm = make_fx(
+                    functionalize(gm),
+                    decomposition_table=get_decompositions(decompose_ops),
+                )(*example_inputs)
+
             gm.print_readable()
             try:
                 imp.import_graph_module(gm)
@@ -107,17 +117,35 @@ def foo(x, y):
         opt_foo = torch.compile(foo, backend=self.create_backend())
         opt_foo(torch.randn(10), torch.randn(10))

-    @unittest.expectedFailure
-    def testImportChunk(self):
-        """
-        Marked as XFail due to Unsupported placeholder node, where FX graph does not return meta_data["tensor_meta"]
-        to create Ops. Same problem occurs with split.Tensor and unbind.int. Needs to identify the root cause.
-        """
-
+    def testImportDecomposeChunk(self):
         def foo_chunk(x):
             return torch.chunk(x, 2, dim=-1)

-        opt = torch.compile(foo_chunk, backend=self.create_backend())
+        opt = torch.compile(
+            foo_chunk,
+            backend=self.create_backend(
+                decompose_ops=[
+                    torch.ops.aten.split.Tensor,
+                    torch.ops.aten.split_with_sizes,
+                ]
+            ),
+        )
+        t = torch.randn([4, 4, 4, 4])
+        opt(t)
+
+    def testImportDecomposeBatchNorm2D(self):
+        def foo_chunk(x):
+            return torch.nn.BatchNorm2d(4)(x)
+
+        opt = torch.compile(
+            foo_chunk,
+            backend=self.create_backend(
+                decompose_ops=[
+                    torch.ops.aten._native_batch_norm_legit_functional,
+                    torch.ops.aten.squeeze.dims,
+                ]
+            ),
+        )
         t = torch.randn([4, 4, 4, 4])
         opt(t)

python/test/dynamo/multiple_aten_results_test.py
Lines changed: 0 additions & 1 deletion

@@ -34,7 +34,6 @@ def import_compiler(gm: GraphModule, example_inputs):
 import torch.nn.functional as F

 class Scaled_Dot_Product_Attention(nn.Module):
-
     def __init__(self):
         super(Scaled_Dot_Product_Attention, self).__init__()

python/test/generated/evaluate.py
Lines changed: 33 additions & 2 deletions

@@ -9,23 +9,54 @@
     GraphModule,
 )

+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._decomp import get_decompositions
+from torch.func import functionalize
+from typing import List
+
+
+def default_decompositions():
+    return get_decompositions(
+        [
+            torch.ops.aten.embedding_dense_backward,
+            torch.ops.aten.native_layer_norm_backward,
+            torch.ops.aten.slice_backward,
+            torch.ops.aten.select_backward,
+            torch.ops.aten.norm.ScalarOpt_dim,
+            torch.ops.aten.native_group_norm,
+            torch.ops.aten.upsample_bilinear2d.vec,
+            torch.ops.aten.split.Tensor,
+            torch.ops.aten.split_with_sizes,
+            torch.ops.aten.native_layer_norm,
+            torch.ops.aten.masked_fill.Tensor,
+            torch.ops.aten.masked_fill.Scalar,
+            torch.ops.aten._native_batch_norm_legit_functional,
+            torch.ops.aten.squeeze.dims,
+        ]
+    )
+
+
 def create_backend():
     imp = FxImporter()

     def import_compiler(gm: GraphModule, example_inputs):
-        # gm.print_readable()
+        gm = make_fx(
+            functionalize(gm),
+            decomposition_table=default_decompositions(),
+        )(*example_inputs)
+
         try:
             imp.import_graph_module(gm)
         finally:
             pass
-            # print(imp.module)
         imp.module.operation.verify()
         return gm

     backend = import_compiler
     backend = aot_autograd(fw_compiler=backend)
     return backend

+
 def evaluate_importer(nn_cls, get_init_args, get_forward_args, test_identifier):
     log = logging.getLogger("turbine-test")
     try:

python/test/generated/main.py
Lines changed: 30 additions & 6 deletions

@@ -8,45 +8,69 @@
 import torch._inductor.config

 import logging
+
 log = logging.getLogger("turbine-test")
 logging.basicConfig(level=logging.INFO)

 ENV_FILE = "JITPARITYBENCH_PATH.txt"

+
 def get_args(raw_args=None):
     parser = argparse.ArgumentParser()
-    parser.add_argument("--jobs", "-j", type=int, default=4, help="Number of threads in our threadpool, jobs=1 is essentially sequential execution")
-    parser.add_argument("--offset", type=int, default=0, help="Pick files starting from this offset. Together with --limit, we can run through all files in multiple separate runs")
+    parser.add_argument(
+        "--jobs",
+        "-j",
+        type=int,
+        default=4,
+        help="Number of threads in our threadpool, jobs=1 is essentially sequential execution",
+    )
+    parser.add_argument(
+        "--offset",
+        type=int,
+        default=0,
+        help="Pick files starting from this offset. Together with --limit, we can run through all files in multiple separate runs",
+    )
     parser.add_argument("--limit", "-l", type=int, help="only run the first N files")
-    parser.add_argument("--filter", "-f", "-k", help="only run module containing given name")
+    parser.add_argument(
+        "--filter", "-f", "-k", help="only run module containing given name"
+    )
     parser.add_argument("--skips", type=str)
-    parser.add_argument("--tests-dir", default=None, help="jit-paritybench location (i.e. /path/to/pytorch-jit-paritybench)")
+    parser.add_argument(
+        "--tests-dir",
+        default=None,
+        help="jit-paritybench location (i.e. /path/to/pytorch-jit-paritybench)",
+    )
     # parser.add_argument("--device", default="cuda", type=str, help="evaluate modules using cuda or cpu") # excluded for now as we only have turbine-cpu, can use this later

     args = parser.parse_args(raw_args)
     return args

+
 def write_path(path: str):
     with open(ENV_FILE, "w") as f:
         f.write(path)

+
 def read_path() -> str:
     with open(ENV_FILE, "r") as f:
         path = f.read()
     return path

+
 if __name__ == "__main__":
     args = get_args()

     if args.tests_dir is not None:
         pb = args.tests_dir
-        write_path(pb) # store this path for next time
+        write_path(pb)  # store this path for next time
         log.info(f"Using test directory from CLI: {pb}")
     elif os.path.exists(ENV_FILE):
         pb = read_path()
         log.info(f"Using test directory from {ENV_FILE}: {pb}")
     else:
-        raise RuntimeError(f"Must either pass 'tests-dir' or set {ENV_FILE} in order to run tests")
+        raise RuntimeError(
+            f"Must either pass 'tests-dir' or set {ENV_FILE} in order to run tests"
+        )

     # enables finding necessary modules in jit-paritybench
     pb_gen = pb + "/generated"
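
Given the flags defined in get_args, a typical invocation would be `python main.py --tests-dir /path/to/pytorch-jit-paritybench --limit 10 -j 4` (path illustrative); subsequent runs can omit `--tests-dir`, since the path is cached in JITPARITYBENCH_PATH.txt by write_path and recovered by read_path.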
