Skip to content

Commit e02f701

Browse files
authored
Merge pull request #85 from iksnagreb/feature/sort_commutative_inputs
Add cleanup transformation sorting inputs of commutative operations
2 parents 1f8938a + cc1b1f0 commit e02f701

File tree

3 files changed

+154
-1
lines changed

3 files changed

+154
-1
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@
3838
import qonnx.util.onnx as onnxutil
3939
from qonnx.core.datatype import DataType
4040
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
41-
from qonnx.transformation.general import RemoveStaticGraphInputs, RemoveUnusedTensors, SortGraph
41+
from qonnx.transformation.general import (
42+
RemoveStaticGraphInputs,
43+
RemoveUnusedTensors,
44+
SortCommutativeInputsInitializerLast,
45+
SortGraph,
46+
)
4247

4348

4449
class ModelWrapper:
@@ -149,6 +154,7 @@ def cleanup(self):
149154
RemoveUnusedTensors(),
150155
RemoveStaticGraphInputs(),
151156
SortGraph(),
157+
SortCommutativeInputsInitializerLast(),
152158
]
153159
for trn in cleanup_transforms:
154160
transformed_model = transformed_model.transform(trn, cleanup=False, make_deepcopy=False)

src/qonnx/transformation/general.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import json
3030
import numpy as np
3131
import warnings
32+
33+
# Protobuf onnx graph node type
34+
from onnx import NodeProto # noqa
3235
from onnx import mapping
3336
from toposort import toposort_flatten
3437

@@ -359,3 +362,56 @@ def apply(self, model):
359362

360363
# one iteration is enough
361364
return (model, False)
365+
366+
367+
# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
368+
# initializers. Keeps order of inputs in each category.
369+
def group_inputs_by_category(node: NodeProto, model): # noqa
370+
# Select all dynamic inputs, which are those without initializer tensor
371+
dynamics = [i for i in node.input if model.get_initializer(i) is None]
372+
# Select all input which are initializers, which, by exclusion, are all
373+
# those not among the dynamic inputs
374+
initializers = [i for i in node.input if i not in dynamics]
375+
# Return lists of dynamic anc initializer inputs
376+
return dynamics, initializers
377+
378+
379+
# Tidy-Up transformation sorting the inputs to all commutative operations to
380+
# have initializer inputs last
381+
class SortCommutativeInputsInitializerLast(Transformation):
382+
"""
383+
Sorts inputs of nodes describing commutative operations to have initializer
384+
inputs last. This order of inputs is assumed by many other transformations.
385+
"""
386+
387+
# Set of supported commutative operations
388+
# TODO: There might be more valid operations
389+
SUPPORTED_COMMUTATIVE_OPS = {"Add", "Mul", "And", "Or", "Xor", "Sum"}
390+
391+
# Applies the transform to a whole model graph
392+
def apply(self, model): # noqa
393+
# Get the model graph out of the model wrapper object
394+
graph = model.graph
395+
# Keep track of whether the graph has been modified
396+
graph_modified = False
397+
# Iterate all nodes in the graph keeping track of the index
398+
for index, node in enumerate(graph.node):
399+
# Check whether this node is among the supported
400+
if node.op_type in self.SUPPORTED_COMMUTATIVE_OPS:
401+
# Group node inputs by category
402+
dynamics, initializers = group_inputs_by_category(node, model)
403+
# Flatten the grouped input list
404+
inputs = [*dynamics, *initializers]
405+
# Length of sorted and original input list must match
406+
assert len(inputs) == len(node.input)
407+
# Reassigned inputs from sorted categories
408+
for i, name in enumerate(inputs):
409+
# The graph has been modified if any input is reordered
410+
if node.input[i] != name:
411+
# Note: This is never reset back to False
412+
graph_modified = True
413+
# Reassign input name at the new index
414+
node.input[i] = name
415+
# Return the transformed model and indicate whether the graph actually
416+
# has been transformed
417+
return model, graph_modified
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Set pytest parameters
2+
import pytest
3+
4+
# Numpy for handling simulation of tensor operations
5+
import numpy as np
6+
7+
# Helper for creating ONNX nodes
8+
from onnx import TensorProto
9+
from onnx import helper as oh
10+
11+
# QONNX wrapper of ONNX model graphs
12+
from qonnx.core.modelwrapper import ModelWrapper
13+
14+
# Execute QONNX model graphs
15+
from qonnx.core.onnx_exec import execute_onnx
16+
17+
# Graph transformation to be tested: Sorts the input list of commutative
18+
# operations to have all dynamic inputs first followed by all initializer inputs
19+
from qonnx.transformation.general import SortCommutativeInputsInitializerLast
20+
21+
# QONNX utility for creating models from ONNX graphs
22+
from qonnx.util.basic import qonnx_make_model
23+
24+
25+
# Specify how many inputs the test should cover
26+
@pytest.mark.parametrize("num_inputs", [4, 5, 6])
27+
# Specify which inputs should be turned into initializers
28+
@pytest.mark.parametrize(
29+
# fmt: off
30+
"initializers", [[], [0], [1], [0, 1], [0, 3], [0, 1, 2, 3]]
31+
# fmt: on
32+
)
33+
# Tests the SortCommutativeInputsInitializerLast transformation
34+
def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
35+
# Generate the input tensor names
36+
inputs = [f"in{i}" for i in range(num_inputs)]
37+
# We will use the Sum ONNX operation to test this behavior, as it allows for
38+
# arbitrary many inputs
39+
node = oh.make_node(
40+
# fmt: off
41+
op_type="Sum", inputs=inputs, outputs=["out"], name="Sum"
42+
# fmt: on
43+
)
44+
# Create value infos for all input and the output tensor
45+
inputs = [
46+
# fmt: off
47+
oh.make_tensor_value_info(i, TensorProto.FLOAT, (16,)) for i in inputs
48+
# fmt: on
49+
]
50+
out = oh.make_tensor_value_info("out", TensorProto.FLOAT, (16,))
51+
# Make a graph comprising the Sum node and value infos for all inputs and
52+
# the output
53+
graph = oh.make_graph([node], inputs=inputs, outputs=[out], name="Sum")
54+
# Wrap the graph in an QONNX model wrapper
55+
model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests"))
56+
# Prepare the execution context
57+
context = {f"in{i}": np.random.rand(16) for i in range(num_inputs)}
58+
# Make sure all inputs are of type float32
59+
context = {key: value.astype(np.float32) for key, value in context.items()}
60+
# Turn selected inputs into initializers
61+
for i in initializers:
62+
model.set_initializer(f"in{i}", context[f"in{i}"])
63+
64+
# Execute the ONNX model before transforming
65+
out_expected = execute_onnx(model, context)["out"]
66+
# Apply the transformation to be tested
67+
# Note: No cleanup, as the tested transformation is part of the cleanup, and
68+
# we want to test this in isolation
69+
model = model.transform(
70+
# fmt: off
71+
SortCommutativeInputsInitializerLast(), cleanup=False
72+
# fmt: on
73+
)
74+
# Execute the ONNX model after transforming
75+
out_produced = execute_onnx(model, context)["out"]
76+
77+
# Start with no initializer input seen so far
78+
seen_initializer = False
79+
# Verify that no "dynamic" input follows an initializer input
80+
for i in model.graph.node[0].input:
81+
# Keep track of when an initializer has been seen
82+
if model.get_initializer(i) is not None:
83+
seen_initializer = True
84+
# If there has already been an initializer, this input must be an
85+
# initializer as well
86+
assert (
87+
not seen_initializer or model.get_initializer(i) is not None
88+
), "Non-initializer input following initializer after sorting"
89+
90+
# Outputs before and after must match
91+
assert np.allclose(out_produced, out_expected)

0 commit comments

Comments
 (0)