1
1
# Set pytest parameters
2
2
import pytest
3
+
3
4
# Numpy for handling simulation of tensor operations
4
5
import numpy as np
6
+
5
7
# Helper for creating ONNX nodes
6
8
from onnx import TensorProto
7
9
from onnx import helper as oh
10
+
8
11
# QONNX wrapper of ONNX model graphs
9
12
from qonnx .core .modelwrapper import ModelWrapper
10
- # QONNX utility for creating models from ONNX graphs
11
- from qonnx .util .basic import qonnx_make_model
13
+
12
14
# Execute QONNX model graphs
13
15
from qonnx .core .onnx_exec import execute_onnx
16
+
14
17
# Graph transformation to be tested: Sorts the input list of commutative
15
18
# operations to have all dynamic inputs first followed by all initializer inputs
16
19
from qonnx .transformation .general import SortCommutativeInputsInitializerLast
17
20
21
+ # QONNX utility for creating models from ONNX graphs
22
+ from qonnx .util .basic import qonnx_make_model
23
+
18
24
19
25
# Specify how many inputs the test should cover
20
26
@pytest .mark .parametrize ("num_inputs" , [4 , 5 , 6 ])
21
27
# Specify which inputs should be turned into initializers
22
28
@pytest .mark .parametrize (
29
+ # fmt: off
23
30
"initializers" , [[], [0 ], [1 ], [0 , 1 ], [0 , 3 ], [0 , 1 , 2 , 3 ]]
31
+ # fmt: on
24
32
)
25
33
# Tests the SortCommutativeInputsInitializerLast transformation
26
34
def test_sort_commutative_inputs_initializer_last (num_inputs , initializers ):
@@ -29,11 +37,15 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
29
37
# We will use the Sum ONNX operation to test this behavior, as it allows for
30
38
# arbitrary many inputs
31
39
node = oh .make_node (
40
+ # fmt: off
32
41
op_type = "Sum" , inputs = inputs , outputs = ["out" ], name = "Sum"
42
+ # fmt: on
33
43
)
34
44
# Create value infos for all input and the output tensor
35
45
inputs = [
46
+ # fmt: off
36
47
oh .make_tensor_value_info (i , TensorProto .FLOAT , (16 ,)) for i in inputs
48
+ # fmt: on
37
49
]
38
50
out = oh .make_tensor_value_info ("out" , TensorProto .FLOAT , (16 ,))
39
51
# Make a graph comprising the Sum node and value infos for all inputs and
@@ -42,9 +54,7 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
42
54
# Wrap the graph in an QONNX model wrapper
43
55
model = ModelWrapper (qonnx_make_model (graph , producer_name = "qonnx-tests" ))
44
56
# Prepare the execution context
45
- context = {
46
- f"in{ i } " : np .random .rand (16 ) for i in range (num_inputs )
47
- }
57
+ context = {f"in{ i } " : np .random .rand (16 ) for i in range (num_inputs )}
48
58
# Make sure all inputs are of type float32
49
59
context = {key : value .astype (np .float32 ) for key , value in context .items ()}
50
60
# Turn selected inputs into initializers
@@ -57,7 +67,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
57
67
# Note: No cleanup, as the tested transformation is part of the cleanup, and
58
68
# we want to test this in isolation
59
69
model = model .transform (
70
+ # fmt: off
60
71
SortCommutativeInputsInitializerLast (), cleanup = False
72
+ # fmt: on
61
73
)
62
74
# Execute the ONNX model after transforming
63
75
out_produced = execute_onnx (model , context )["out" ]
@@ -71,8 +83,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
71
83
seen_initializer = True
72
84
# If there has already been an initializer, this input must be an
73
85
# initializer as well
74
- assert not seen_initializer or model .get_initializer (i ) is not None , \
75
- "Non-initializer input following initializer after sorting"
86
+ assert (
87
+ not seen_initializer or model .get_initializer (i ) is not None
88
+ ), "Non-initializer input following initializer after sorting"
76
89
77
90
# Outputs before and after must match
78
91
assert np .allclose (out_produced , out_expected )
0 commit comments