# pytest for parametrizing the tests
import pytest

# NumPy for generating random test data and comparing simulation outputs
import numpy as np

# ONNX data types and helper functions for creating ONNX nodes and graphs
from onnx import TensorProto
from onnx import helper as oh

# QONNX wrapper of ONNX model graphs
from qonnx.core.modelwrapper import ModelWrapper

# Execute QONNX model graphs
from qonnx.core.onnx_exec import execute_onnx

# Graph transformation to be tested: Sorts the inputs of commutative
# operations so that all dynamic inputs come first, followed by all
# initializer inputs
from qonnx.transformation.general import SortCommutativeInputsInitializerLast
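# Illustrative sketch (comments only, not executed): for a commutative node
# with inputs ["in0", "in1", "in2", "in3"] where "in0" and "in2" carry
# initializers, the transformation is expected to produce an input order such
# as ["in1", "in3", "in0", "in2"], i.e. dynamic inputs first and initializer
# inputs last; the relative order within each group is not asserted by this
# test.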

# QONNX utility for creating models from ONNX graphs
from qonnx.util.basic import qonnx_make_model


# Specify how many inputs the test should cover
@pytest.mark.parametrize("num_inputs", [4, 5, 6])
# Specify which inputs should be turned into initializers
@pytest.mark.parametrize(
    # fmt: off
    "initializers", [[], [0], [1], [0, 1], [0, 3], [0, 1, 2, 3]]
    # fmt: on
)
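# Note: Stacked parametrize decorators combine as a cross product, so this
# yields 3 x 6 = 18 test cases.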
# Tests the SortCommutativeInputsInitializerLast transformation
def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
    # Generate the input tensor names
    inputs = [f"in{i}" for i in range(num_inputs)]
    # Use the Sum ONNX operation to test this behavior, as it is commutative
    # and accepts arbitrarily many inputs
    node = oh.make_node(
        # fmt: off
        op_type="Sum", inputs=inputs, outputs=["out"], name="Sum"
        # fmt: on
    )
    # Create value infos for all input tensors and the output tensor
    inputs = [
        # fmt: off
        oh.make_tensor_value_info(i, TensorProto.FLOAT, (16,)) for i in inputs
        # fmt: on
    ]
    out = oh.make_tensor_value_info("out", TensorProto.FLOAT, (16,))
    # Make a graph comprising the Sum node and value infos for all inputs and
    # the output
    graph = oh.make_graph([node], inputs=inputs, outputs=[out], name="Sum")
    # Wrap the graph in a QONNX ModelWrapper
    model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests"))
    # Prepare the execution context with random input data
    context = {f"in{i}": np.random.rand(16) for i in range(num_inputs)}
    # Make sure all inputs are of type float32
    context = {key: value.astype(np.float32) for key, value in context.items()}
    # Turn the selected inputs into initializers
    for i in initializers:
        model.set_initializer(f"in{i}", context[f"in{i}"])
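    # Note: The corresponding entries stay in the execution context as well;
    # as these are the exact same arrays used as initializer values, the
    # simulated output does not depend on which of the two is consumed.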

    # Execute the ONNX model before transforming
    out_expected = execute_onnx(model, context)["out"]
    # Apply the transformation to be tested
    # Note: Do not run cleanup here, as the tested transformation is itself
    # part of the cleanup transformations and we want to test it in isolation
    model = model.transform(
        # fmt: off
        SortCommutativeInputsInitializerLast(), cleanup=False
        # fmt: on
    )
    # Execute the ONNX model after transforming
    out_produced = execute_onnx(model, context)["out"]

    # Start with no initializer input seen so far
    seen_initializer = False
    # Verify that no "dynamic" input follows an initializer input
    for i in model.graph.node[0].input:
        # Keep track of when an initializer has been seen
        if model.get_initializer(i) is not None:
            seen_initializer = True
        # If there has already been an initializer, this input must be an
        # initializer as well
        assert (
            not seen_initializer or model.get_initializer(i) is not None
        ), "Non-initializer input following initializer after sorting"
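    # Illustration (comment only): with "in0" as the sole initializer input,
    # an order such as ["in1", "in2", "in0"] passes this check, while
    # ["in0", "in1", "in2"] or ["in1", "in0", "in2"] would trip the assertion,
    # as a dynamic input follows the initializer.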

    # Outputs before and after the transformation must match
    assert np.allclose(out_produced, out_expected)