|
29 | 29 | import json
|
30 | 30 | import numpy as np
|
31 | 31 | import warnings
|
| 32 | + |
| 33 | +# Protobuf onnx graph node type |
| 34 | +from onnx import NodeProto # noqa |
32 | 35 | from onnx import mapping
|
33 | 36 | from toposort import toposort_flatten
|
34 | 37 |
|
@@ -359,3 +362,57 @@ def apply(self, model):
|
359 | 362 |
|
360 | 363 | # one iteration is enough
|
361 | 364 | return (model, False)
|
| 365 | + |
| 366 | + |
| 367 | +# Groups inputs by categories, i.e., groups dynamic inputs first, followed by |
| 368 | +# initializers. Keeps order of inputs in each category. |
| 369 | +def group_inputs_by_category(node: NodeProto, model): # noqa |
| 370 | + # Select all dynamic inputs, which are those without initializer tensor |
| 371 | + dynamics = [i for i in node.input if model.get_initializer(i) is None] |
| 372 | + # Select all input which are initializers, which, by exclusion, are all |
| 373 | + # those not among the dynamic inputs |
| 374 | + initializers = [i for i in node.input if i not in dynamics] |
| 375 | + # Return lists of dynamic anc initializer inputs |
| 376 | + return dynamics, initializers |
| 377 | + |
| 378 | + |
| 379 | +# Tidy-Up transformation sorting the inputs to all commutative operations to |
| 380 | +# have initializer inputs last |
| 381 | +class SortCommutativeInputsInitializerLast(Transformation): |
| 382 | + """ |
| 383 | + Sorts inputs of nodes describing commutative operations to have initializer |
| 384 | + inputs last. This order of inputs is assumed by many other transformations. |
| 385 | + """ |
| 386 | + |
| 387 | + # Set of supported commutative operations |
| 388 | + # TODO: There might be more valid operations |
| 389 | + SUPPORTED_COMMUTATIVE_OPS = {"Add", "Mul", "And", "Or", "Xor", "Sum"} |
| 390 | + |
| 391 | + # Applies the transform to a whole model graph |
| 392 | + def apply(self, model): # noqa |
| 393 | + # Get the model graph out of the model wrapper object |
| 394 | + graph = model.graph |
| 395 | + # Keep track of whether the graph has been modified |
| 396 | + graph_modified = False |
| 397 | + # Iterate all nodes in the graph keeping track of the index |
| 398 | + for index, node in enumerate(graph.node): |
| 399 | + # Check whether this node is among the supported |
| 400 | + if node.op_type in self.SUPPORTED_COMMUTATIVE_OPS: |
| 401 | + # Group node inputs by category |
| 402 | + dynamics, initializers = group_inputs_by_category(node, model) |
| 403 | + # Flatten the grouped input list |
| 404 | + inputs = [*dynamics, *initializers] |
| 405 | + # Length of sorted and original input list must match |
| 406 | + assert len(inputs) == len(node.input) |
| 407 | + # Reassigned inputs from sorted categories |
| 408 | + # Note: ONNX does not allow direct assignment to node.input |
| 409 | + for i, name in enumerate(inputs): |
| 410 | + # The graph has been modified if any input is reordered |
| 411 | + if node.input[i] != name: |
| 412 | + # Note: This is never reset back to False |
| 413 | + graph_modified = True |
| 414 | + # Reassign input name at the new index |
| 415 | + node.input[i] = name |
| 416 | + # Return the transformed model and indicate whether the graph actually |
| 417 | + # has been transformed |
| 418 | + return model, graph_modified |
0 commit comments