Skip to content

Commit 8bba76b

Browse files
committed
Merge branch 'feature/sort_commutative_inputs' of https://github.com/iksnagreb/qonnx into iksnagreb-feature/sort_commutative_inputs
2 parents fe4aa37 + c0f5b46 commit 8bba76b

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@
3838
import qonnx.util.onnx as onnxutil
3939
from qonnx.core.datatype import DataType
4040
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
41-
from qonnx.transformation.general import RemoveStaticGraphInputs, RemoveUnusedTensors, SortGraph
41+
from qonnx.transformation.general import (
42+
RemoveStaticGraphInputs,
43+
RemoveUnusedTensors,
44+
SortCommutativeInputsInitializerLast,
45+
SortGraph,
46+
)
4247

4348

4449
class ModelWrapper:
@@ -149,6 +154,7 @@ def cleanup(self):
149154
RemoveUnusedTensors(),
150155
RemoveStaticGraphInputs(),
151156
SortGraph(),
157+
SortCommutativeInputsInitializerLast(),
152158
]
153159
for trn in cleanup_transforms:
154160
transformed_model = transformed_model.transform(trn, cleanup=False, make_deepcopy=False)

src/qonnx/transformation/general.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import json
3030
import numpy as np
3131
import warnings
32+
33+
# Protobuf onnx graph node type
34+
from onnx import NodeProto # noqa
3235
from onnx import mapping
3336
from toposort import toposort_flatten
3437

@@ -359,3 +362,57 @@ def apply(self, model):
359362

360363
# one iteration is enough
361364
return (model, False)
365+
366+
367+
# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
368+
# initializers. Keeps order of inputs in each category.
369+
def group_inputs_by_category(node: NodeProto, model): # noqa
370+
# Select all dynamic inputs, which are those without initializer tensor
371+
dynamics = [i for i in node.input if model.get_initializer(i) is None]
372+
# Select all input which are initializers, which, by exclusion, are all
373+
# those not among the dynamic inputs
374+
initializers = [i for i in node.input if i not in dynamics]
375+
# Return lists of dynamic anc initializer inputs
376+
return dynamics, initializers
377+
378+
379+
# Tidy-Up transformation sorting the inputs to all commutative operations to
380+
# have initializer inputs last
381+
class SortCommutativeInputsInitializerLast(Transformation):
382+
"""
383+
Sorts inputs of nodes describing commutative operations to have initializer
384+
inputs last. This order of inputs is assumed by many other transformations.
385+
"""
386+
387+
# Set of supported commutative operations
388+
# TODO: There might be more valid operations
389+
SUPPORTED_COMMUTATIVE_OPS = {"Add", "Mul", "And", "Or", "Xor", "Sum"}
390+
391+
# Applies the transform to a whole model graph
392+
def apply(self, model): # noqa
393+
# Get the model graph out of the model wrapper object
394+
graph = model.graph
395+
# Keep track of whether the graph has been modified
396+
graph_modified = False
397+
# Iterate all nodes in the graph keeping track of the index
398+
for index, node in enumerate(graph.node):
399+
# Check whether this node is among the supported
400+
if node.op_type in self.SUPPORTED_COMMUTATIVE_OPS:
401+
# Group node inputs by category
402+
dynamics, initializers = group_inputs_by_category(node, model)
403+
# Flatten the grouped input list
404+
inputs = [*dynamics, *initializers]
405+
# Length of sorted and original input list must match
406+
assert len(inputs) == len(node.input)
407+
# Reassigned inputs from sorted categories
408+
# Note: ONNX does not allow direct assignment to node.input
409+
for i, name in enumerate(inputs):
410+
# The graph has been modified if any input is reordered
411+
if node.input[i] != name:
412+
# Note: This is never reset back to False
413+
graph_modified = True
414+
# Reassign input name at the new index
415+
node.input[i] = name
416+
# Return the transformed model and indicate whether the graph actually
417+
# has been transformed
418+
return model, graph_modified

0 commit comments

Comments
 (0)