Skip to content

Commit 38df9fb

Browse files
committed
Address some linting issues
1 parent 7ebbeac commit 38df9fb

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

src/qonnx/transformation/general.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,6 @@ def apply(self, model): # noqa
405405
# Length of sorted and original input list must match
406406
assert len(inputs) == len(node.input)
407407
# Reassigned inputs from sorted categories
408-
# Note: ONNX does not allow direct assignment to node.input
409408
for i, name in enumerate(inputs):
410409
# The graph has been modified if any input is reordered
411410
if node.input[i] != name:

tests/transformation/test_sort_commutative_inputs_initializer_last.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,34 @@
11
# Set pytest parameters
22
import pytest
3+
34
# Numpy for handling simulation of tensor operations
45
import numpy as np
6+
57
# Helper for creating ONNX nodes
68
from onnx import TensorProto
79
from onnx import helper as oh
10+
811
# QONNX wrapper of ONNX model graphs
912
from qonnx.core.modelwrapper import ModelWrapper
10-
# QONNX utility for creating models from ONNX graphs
11-
from qonnx.util.basic import qonnx_make_model
13+
1214
# Execute QONNX model graphs
1315
from qonnx.core.onnx_exec import execute_onnx
16+
1417
# Graph transformation to be tested: Sorts the input list of commutative
1518
# operations to have all dynamic inputs first followed by all initializer inputs
1619
from qonnx.transformation.general import SortCommutativeInputsInitializerLast
1720

21+
# QONNX utility for creating models from ONNX graphs
22+
from qonnx.util.basic import qonnx_make_model
23+
1824

1925
# Specify how many inputs the test should cover
2026
@pytest.mark.parametrize("num_inputs", [4, 5, 6])
2127
# Specify which inputs should be turned into initializers
2228
@pytest.mark.parametrize(
29+
# fmt: off
2330
"initializers", [[], [0], [1], [0, 1], [0, 3], [0, 1, 2, 3]]
31+
# fmt: on
2432
)
2533
# Tests the SortCommutativeInputsInitializerLast transformation
2634
def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
@@ -29,11 +37,15 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
2937
# We will use the Sum ONNX operation to test this behavior, as it allows for
3038
# arbitrary many inputs
3139
node = oh.make_node(
40+
# fmt: off
3241
op_type="Sum", inputs=inputs, outputs=["out"], name="Sum"
42+
# fmt: on
3343
)
3444
# Create value infos for all input and the output tensor
3545
inputs = [
46+
# fmt: off
3647
oh.make_tensor_value_info(i, TensorProto.FLOAT, (16,)) for i in inputs
48+
# fmt: on
3749
]
3850
out = oh.make_tensor_value_info("out", TensorProto.FLOAT, (16,))
3951
# Make a graph comprising the Sum node and value infos for all inputs and
@@ -42,9 +54,7 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
4254
# Wrap the graph in an QONNX model wrapper
4355
model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests"))
4456
# Prepare the execution context
45-
context = {
46-
f"in{i}": np.random.rand(16) for i in range(num_inputs)
47-
}
57+
context = {f"in{i}": np.random.rand(16) for i in range(num_inputs)}
4858
# Make sure all inputs are of type float32
4959
context = {key: value.astype(np.float32) for key, value in context.items()}
5060
# Turn selected inputs into initializers
@@ -57,7 +67,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
5767
# Note: No cleanup, as the tested transformation is part of the cleanup, and
5868
# we want to test this in isolation
5969
model = model.transform(
70+
# fmt: off
6071
SortCommutativeInputsInitializerLast(), cleanup=False
72+
# fmt: on
6173
)
6274
# Execute the ONNX model after transforming
6375
out_produced = execute_onnx(model, context)["out"]
@@ -71,8 +83,9 @@ def test_sort_commutative_inputs_initializer_last(num_inputs, initializers):
7183
seen_initializer = True
7284
# If there has already been an initializer, this input must be an
7385
# initializer as well
74-
assert not seen_initializer or model.get_initializer(i) is not None, \
75-
"Non-initializer input following initializer after sorting"
86+
assert (
87+
not seen_initializer or model.get_initializer(i) is not None
88+
), "Non-initializer input following initializer after sorting"
7689

7790
# Outputs before and after must match
7891
assert np.allclose(out_produced, out_expected)

0 commit comments

Comments
 (0)