Skip to content

Commit 7598ff9

Browse files
committed
[Lint] run pre-commit for PR #125
1 parent 5cc8da7 commit 7598ff9

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

src/qonnx/transformation/change_3d_tensors_to_4d.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,10 @@ def apply(self, model):
196196
scales = np.append(scales, np.asarray(1.0, dtype=np.float32))
197197
model.set_initializer(n.input[1], scales)
198198
elif node_op_type == "Resize":
199-
assert ("axes" not in [x.name for x in n.attribute]), (
200-
"%s: Axes attribute is not supported." % n.name
201-
)
202-
assert (not (len(n.input) in (3, 4) and model.get_initializer(n.input[1]) is not None)), (
203-
"%s: ROI input is not supported." % n.name
204-
)
199+
assert "axes" not in [x.name for x in n.attribute], "%s: Axes attribute is not supported." % n.name
200+
assert not (len(n.input) in (3, 4) and model.get_initializer(n.input[1]) is not None), (
201+
"%s: ROI input is not supported." % n.name
202+
)
205203
if len(n.input) == 2:
206204
# Resize version 10
207205
scales = model.get_initializer(n.input[1])
@@ -213,13 +211,17 @@ def apply(self, model):
213211
scales = np.append(scales, np.asarray(1.0, dtype=np.float32))
214212
model.set_initializer(n.input[2], scales)
215213
elif len(n.input) == 4:
216-
scales_exists = (model.get_initializer(n.input[2]) is not None) and (len(model.get_initializer(n.input[2])) != 0)
217-
sizes_exists = (model.get_initializer(n.input[3]) is not None) and (len(model.get_initializer(n.input[3])) != 0)
218-
assert (scales_exists ^ sizes_exists), (
219-
"%s: Either scales or the target output size must "
214+
scales_exists = (model.get_initializer(n.input[2]) is not None) and (
215+
len(model.get_initializer(n.input[2])) != 0
216+
)
217+
sizes_exists = (model.get_initializer(n.input[3]) is not None) and (
218+
len(model.get_initializer(n.input[3])) != 0
219+
)
220+
assert scales_exists ^ sizes_exists, (
221+
"%s: Either scales or the target output size must "
220222
"be specified. Specifying both is prohibited." % n.name
221-
)
222-
if (scales_exists):
223+
)
224+
if scales_exists:
223225
# Scales parameter is a 1d list of upsampling factors along each axis
224226
scales = model.get_initializer(n.input[2])
225227
scales = np.append(scales, np.asarray(1.0, dtype=np.float32))

tests/transformation/test_4d_conversion.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def create_conv_upsample():
327327
model.set_initializer(tensor_name, gen_finn_dt_tensor(DataType["FLOAT32"], init_shape))
328328
return model
329329

330+
330331
def create_resize():
331332
"""
332333
Creates an model for testing the 3D to 4D transform of the resize node.
@@ -346,18 +347,18 @@ def create_resize():
346347
name="Resize2",
347348
mode="nearest",
348349
)
349-
350+
350351
in_resize1 = onnx.helper.make_tensor_value_info("in_resize1", onnx.TensorProto.FLOAT, [1, 32, 4])
351-
out_resize1 = onnx.helper.make_tensor_value_info("out_resize1", onnx.TensorProto.FLOAT, [1, 32, 8])
352-
out_resize2 = onnx.helper.make_tensor_value_info("out_resize2", onnx.TensorProto.FLOAT, [1, 32, 16])
353-
352+
out_resize1 = onnx.helper.make_tensor_value_info("out_resize1", onnx.TensorProto.FLOAT, [1, 32, 8])
353+
out_resize2 = onnx.helper.make_tensor_value_info("out_resize2", onnx.TensorProto.FLOAT, [1, 32, 16])
354+
354355
roi_resize1 = onnx.helper.make_tensor_value_info("roi_resize1", onnx.TensorProto.FLOAT, [4])
355356
scales_resize1 = onnx.helper.make_tensor_value_info("scales_resize1", onnx.TensorProto.FLOAT, [])
356357
sizes_resize1 = onnx.helper.make_tensor_value_info("sizes_resize1", onnx.TensorProto.INT64, [3])
357358

358359
roi_resize2 = onnx.helper.make_tensor_value_info("roi_resize2", onnx.TensorProto.FLOAT, [4])
359360
scales_resize2 = onnx.helper.make_tensor_value_info("scales_resize2", onnx.TensorProto.FLOAT, [3])
360-
361+
361362
list_of_nodes = [
362363
resize_node1,
363364
resize_node2,
@@ -384,9 +385,10 @@ def create_resize():
384385
model = model.transform(InferShapes())
385386
model.set_initializer("sizes_resize1", np.array([1, 32, 8], dtype=np.int64))
386387
model.set_initializer("scales_resize1", np.array([], dtype=np.float32))
387-
model.set_initializer("scales_resize2", np.array([1., 1., 2.], dtype=np.float32))
388+
model.set_initializer("scales_resize2", np.array([1.0, 1.0, 2.0], dtype=np.float32))
388389
return model
389390

391+
390392
@pytest.mark.parametrize("test_model", ["Quartz", "VGG", "ConvUpsample", "Resize"])
391393
def test_4d_conversion(test_model):
392394
"""

0 commit comments

Comments
 (0)