Skip to content

Commit 7ebbeac

Browse files
committed
Add unit test for SortCommutativeInputsInitializerLast transformation
1 parent 8bba76b commit 7ebbeac

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Set pytest parameters
2+
import pytest
3+
# Numpy for handling simulation of tensor operations
4+
import numpy as np
5+
# Helper for creating ONNX nodes
6+
from onnx import TensorProto
7+
from onnx import helper as oh
8+
# QONNX wrapper of ONNX model graphs
9+
from qonnx.core.modelwrapper import ModelWrapper
10+
# QONNX utility for creating models from ONNX graphs
11+
from qonnx.util.basic import qonnx_make_model
12+
# Execute QONNX model graphs
13+
from qonnx.core.onnx_exec import execute_onnx
14+
# Graph transformation to be tested: Sorts the input list of commutative
15+
# operations to have all dynamic inputs first followed by all initializer inputs
16+
from qonnx.transformation.general import SortCommutativeInputsInitializerLast
17+
18+
19+
# Specify how many inputs the test should cover
20+
@pytest.mark.parametrize("num_inputs", [4, 5, 6])
21+
# Specify which inputs should be turned into initializers
22+
@pytest.mark.parametrize(
23+
"initializers", [[], [0], [1], [0, 1], [0, 3], [0, 1, 2, 3]]
24+
)
25+
# Tests the SortCommutativeInputsInitializerLast transformation
26+
def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
27+
# Generate the input tensor names
28+
inputs = [f"in{i}" for i in range(num_inputs)]
29+
# We will use the Sum ONNX operation to test this behavior, as it allows for
30+
# arbitrary many inputs
31+
node = oh.make_node(
32+
op_type="Sum", inputs=inputs, outputs=["out"], name="Sum"
33+
)
34+
# Create value infos for all input and the output tensor
35+
inputs = [
36+
oh.make_tensor_value_info(i, TensorProto.FLOAT, (16,)) for i in inputs
37+
]
38+
out = oh.make_tensor_value_info("out", TensorProto.FLOAT, (16,))
39+
# Make a graph comprising the Sum node and value infos for all inputs and
40+
# the output
41+
graph = oh.make_graph([node], inputs=inputs, outputs=[out], name="Sum")
42+
# Wrap the graph in an QONNX model wrapper
43+
model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests"))
44+
# Prepare the execution context
45+
context = {
46+
f"in{i}": np.random.rand(16) for i in range(num_inputs)
47+
}
48+
# Make sure all inputs are of type float32
49+
context = {key: value.astype(np.float32) for key, value in context.items()}
50+
# Turn selected inputs into initializers
51+
for i in initializers:
52+
model.set_initializer(f"in{i}", context[f"in{i}"])
53+
54+
# Execute the ONNX model before transforming
55+
out_expected = execute_onnx(model, context)["out"]
56+
# Apply the transformation to be tested
57+
# Note: No cleanup, as the tested transformation is part of the cleanup, and
58+
# we want to test this in isolation
59+
model = model.transform(
60+
SortCommutativeInputsInitializerLast(), cleanup=False
61+
)
62+
# Execute the ONNX model after transforming
63+
out_produced = execute_onnx(model, context)["out"]
64+
65+
# Start with no initializer input seen so far
66+
seen_initializer = False
67+
# Verify that no "dynamic" input follows an initializer input
68+
for i in model.graph.node[0].input:
69+
# Keep track of when an initializer has been seen
70+
if model.get_initializer(i) is not None:
71+
seen_initializer = True
72+
# If there has already been an initializer, this input must be an
73+
# initializer as well
74+
assert not seen_initializer or model.get_initializer(i) is not None, \
75+
"Non-initializer input following initializer after sorting"
76+
77+
# Outputs before and after must match
78+
assert np.allclose(out_produced, out_expected)

0 commit comments

Comments
 (0)