Skip to content

Commit 60f8692

Browse files
titaiwangms, justinchuby, dependabot[bot]
authored
[pass] Add CSE (#36)
CSE pass from microsoft/onnxscript#2304 --------- Signed-off-by: Ti-Tai Wang <titaiwang@microsoft.com> Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
1 parent 322205f commit 60f8692

File tree

4 files changed

+509
-0
lines changed

4 files changed

+509
-0
lines changed

noxfile.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
"types-PyYAML",
2828
"typing_extensions>=4.10",
2929
"ml-dtypes",
30+
"onnxruntime",
3031
)
3132
ONNX = "onnx==1.18"
33+
ONNXSCRIPT = "onnxscript"
3234
ONNX_RUNTIME = "onnxruntime==1.20.1"
3335
PYTORCH = "torch==2.7.0"
3436
TORCHVISON = "torchvision==0.22.0"
@@ -50,6 +52,7 @@ def test(session):
5052
ONNX,
5153
PYTORCH,
5254
)
55+
session.install(ONNXSCRIPT, "--no-deps")
5356
session.install(".", "--no-deps")
5457
session.run("pip", "list")
5558
session.run("pytest", "src", "--doctest-modules", *session.posargs)
@@ -61,6 +64,7 @@ def test_onnx_weekly(session):
6164
"""Test with ONNX weekly (preview) build."""
6265
session.install(*COMMON_TEST_DEPENDENCIES, PYTORCH)
6366
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
67+
session.install(ONNXSCRIPT, "--no-deps")
6468
session.install(".", "--no-deps")
6569
session.run("pip", "list")
6670
session.run("pytest", "src", "--doctest-modules", *session.posargs)
@@ -73,6 +77,7 @@ def test_torch_nightly(session):
7377
session.install(*COMMON_TEST_DEPENDENCIES)
7478
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
7579
session.install("-r", "requirements/ci/requirements-pytorch-nightly.txt")
80+
session.install(ONNXSCRIPT, "--no-deps")
7681
session.install(".", "--no-deps")
7782
session.run("pip", "list")
7883
session.run("pytest", "src", "--doctest-modules", *session.posargs)

src/onnx_ir/passes/common/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"AddInitializersToInputsPass",
66
"CheckerPass",
77
"ClearMetadataAndDocStringPass",
8+
"CommonSubexpressionEliminationPass",
89
"InlinePass",
910
"LiftConstantsToInitializersPass",
1011
"LiftSubgraphInitializersToMainGraphPass",
@@ -19,6 +20,9 @@
1920
from onnx_ir.passes.common.clear_metadata_and_docstring import (
2021
ClearMetadataAndDocStringPass,
2122
)
23+
from onnx_ir.passes.common.common_subexpression_elimination import (
24+
CommonSubexpressionEliminationPass,
25+
)
2226
from onnx_ir.passes.common.constant_manipulation import (
2327
AddInitializersToInputsPass,
2428
LiftConstantsToInitializersPass,
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Eliminate common subexpression in ONNX graphs."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"CommonSubexpressionEliminationPass",
9+
]
10+
11+
import logging
12+
from collections.abc import Sequence
13+
14+
import onnx_ir as ir
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """In-place pass that removes duplicated computations from an ONNX model.

    The work is delegated to ``_eliminate_common_subexpression``, which is
    invoked on ``model.graph`` only.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Apply CSE to ``model.graph`` and report whether anything changed.

        Args:
            model: The model whose main graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same ``model`` object, with
            ``modified`` set to the flag returned by the elimination helper.
        """
        # The helper takes the "modified so far" flag; nothing has changed yet.
        changed = _eliminate_common_subexpression(model.graph, False)
        return ir.passes.PassResult(model, modified=changed)
33+
34+
35+
def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Merge nodes of ``graph`` that compute the same expression.

    Two nodes are treated as the same expression when they agree on all of:
    the operator identifier, the number of outputs, the identity (``id``) of
    every input value, and the attributes. The first node seen with a given
    signature is kept; each later node with the same signature is removed and
    its outputs are redirected to the kept node's outputs.

    Nodes that carry a GRAPH or GRAPHS attribute, and nodes reported by
    ``_is_non_deterministic_op``, are never merged.

    Args:
        graph: The graph to rewrite in place.
        modified: The modification flag accumulated so far by the caller.

    Returns:
        ``True`` if at least one node was eliminated, otherwise the incoming
        ``modified`` value unchanged.
    """
    # Signature of a node -> the first node seen with that signature.
    existing_node_info_to_the_node: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        attributes: dict[str, object] = {}
        has_subgraph_attribute = False
        for attr_name, attr in node.attributes.items():
            # TODO(exporter team): CSE subgraphs.
            # NOTE: control flow ops like Loop and If won't be CSEd
            # because attribute: graph won't match.
            if attr.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
                # The node is skipped as a whole, so the remaining attributes
                # do not need to be inspected.
                has_subgraph_attribute = True
                break
            # The attribute value could be directly taken from the original
            # protobuf, so we need to make a copy of it.
            value = attr.value
            if attr.type in (
                ir.AttributeType.INTS,
                ir.AttributeType.FLOATS,
                ir.AttributeType.STRINGS,
            ):
                # INTS, FLOATS and STRINGS values are sequences; convert them
                # to tuples to ensure they are hashable.
                value = tuple(value)
            # NOTE(review): values of other attribute types are stored as-is;
            # presumably they are hashable — confirm for list-valued types.
            attributes[attr_name] = value

        if has_subgraph_attribute:
            # Logged once per skipped node.
            logger.debug("Skipping control flow op %s", node)
            continue

        if _is_non_deterministic_op(node):
            # Merging two random draws would change the model's semantics.
            logger.debug("Skipping non-deterministic op %s", node)
            continue

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            # Inputs are compared by object identity, not by value.
            tuple(id(input_value) for input_value in node.inputs),
            # Sorted so that attribute insertion order does not matter.
            tuple(sorted(attributes.items())),
        )
        existing_node = existing_node_info_to_the_node.get(node_info)
        if existing_node is None:
            # First occurrence of this signature: remember it.
            existing_node_info_to_the_node[node_info] = node
            continue

        # A node with the same operator, number of outputs, inputs and
        # attributes already exists: reuse it and drop the current node.
        modified = True
        _remove_node_and_replace_values(
            graph,
            remove_node=node,
            remove_values=node.outputs,
            new_values=existing_node.outputs,
        )
        logger.debug("Reusing node %s", existing_node)
    return modified
107+
108+
109+
def _remove_node_and_replace_values(
    graph: ir.Graph,
    /,
    remove_node: ir.Node,
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Drop ``remove_node`` from ``graph`` and reroute its values.

    Every use of a value in ``remove_values`` is redirected to the value at
    the same position in ``new_values``. Graph outputs that referred to a
    removed value are then rewired as well, and finally the node is removed.

    Args:
        graph: The graph to replace nodes and values in.
        remove_node: The node to remove.
        remove_values: The values to replace.
        new_values: The values to replace with, paired positionally with
            ``remove_values``.
    """
    # Point all consumers of the old values at their replacements first.
    ir.convenience.replace_all_uses_with(remove_values, new_values)

    # Graph outputs need separate handling when the removed node feeds any.
    if any(old_value.is_graph_output() for old_value in remove_values):
        old_to_new = dict(zip(remove_values, new_values))
        for position, output in enumerate(graph.outputs):
            if output not in old_to_new:
                continue

            replacement = old_to_new[output]
            if not (replacement.is_graph_output() or replacement.is_graph_input()):
                # The replacement is an internal value: let it take over the
                # output slot together with the output's name.
                replacement.name = output.name
                graph.outputs[position] = replacement
                continue

            # The replacement is itself a graph input/output, so renaming it
            # would change another boundary name. Route it through an
            # Identity node whose output reuses the old output's name, type
            # and shape instead.
            preserved_output = ir.Value(
                name=output.name,
                type=output.type,
                shape=output.shape,
            )
            identity = ir.node(
                "Identity",
                inputs=[replacement],
                outputs=[preserved_output],
            )
            graph.outputs[position] = identity.outputs[0]
            graph.insert_before(remove_node, identity)

    graph.remove(remove_node, safe=True)
160+
161+
162+
def _is_non_deterministic_op(node: ir.Node) -> bool:
    """Return whether ``node`` is one of the listed random ONNX-domain ops.

    A node qualifies when its ``op_type`` is in the set below and its
    ``domain`` is accepted by ``_is_onnx_domain``.
    """
    random_op_types = frozenset(
        (
            "RandomUniform",
            "RandomNormal",
            "RandomUniformLike",
            "RandomNormalLike",
            "Multinomial",
        )
    )
    if node.op_type not in random_op_types:
        return False
    # The op name alone is not enough; it must also be in the ONNX domain.
    return _is_onnx_domain(node.domain)
173+
174+
175+
def _is_onnx_domain(d: str) -> bool:
    """Check if the domain is the default ONNX operator domain.

    The default domain may be spelled either as the empty string or as
    ``"ai.onnx"``; both spellings are accepted so that an operator declared
    with the explicit name is classified the same way as one declared with
    the empty string.

    Args:
        d: The operator domain to test.

    Returns:
        ``True`` when ``d`` is ``""`` or ``"ai.onnx"``, ``False`` otherwise.
    """
    return d in ("", "ai.onnx")

0 commit comments

Comments
 (0)