Skip to content

Commit ac15c7d

Browse files
committed
[Test] enhance test_4d_conversion with opset=11,13 Resize variants
1 parent f25557b commit ac15c7d

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

tests/transformation/test_4d_conversion.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import numpy as np
3232
import onnx
3333
import onnx.parser as oprs
34+
from onnx.helper import make_opsetid
3435

3536
import qonnx.core.onnx_exec as oxe
3637
from qonnx.core.datatype import DataType
@@ -328,7 +329,7 @@ def create_conv_upsample():
328329
return model
329330

330331

331-
def create_resize():
332+
def create_resize(opset):
332333
"""
333334
Creates an model for testing the 3D to 4D transform of the resize node.
334335
"""
@@ -380,16 +381,24 @@ def create_resize():
380381
value_info=list_of_value_infos,
381382
)
382383

383-
onnx_model = qonnx_make_model(graph, producer_name="4d_conversion_resize_test-model")
384+
onnx_model = qonnx_make_model(
385+
graph, producer_name="4d_conversion_resize_test-model", opset_imports=[make_opsetid("", opset)]
386+
)
384387
model = ModelWrapper(onnx_model)
385-
model = model.transform(InferShapes())
388+
386389
model.set_initializer("sizes_resize1", np.array([1, 32, 8], dtype=np.int64))
387-
model.set_initializer("scales_resize1", np.array([], dtype=np.float32))
390+
if opset == 11:
391+
model.set_initializer("scales_resize1", np.array([], dtype=np.float32))
392+
elif opset == 13:
393+
model.graph.node[0].input[2] = ""
394+
else:
395+
assert False, f"Undefined opset {opset} for Resize testcase creator"
388396
model.set_initializer("scales_resize2", np.array([1.0, 1.0, 2.0], dtype=np.float32))
397+
model = model.transform(InferShapes())
389398
return model
390399

391400

392-
@pytest.mark.parametrize("test_model", ["Quartz", "VGG", "ConvUpsample", "Resize"])
401+
@pytest.mark.parametrize("test_model", ["Quartz", "VGG", "ConvUpsample", "Resize11", "Resize13"])
393402
def test_4d_conversion(test_model):
394403
"""
395404
Test for the 3D to 4D transformation with a valid graph.
@@ -401,8 +410,8 @@ def test_4d_conversion(test_model):
401410
model = create_arbitrary_model_vgg()
402411
elif test_model == "ConvUpsample":
403412
model = create_conv_upsample()
404-
elif test_model == "Resize":
405-
model = create_resize()
413+
elif "Resize" in test_model:
414+
model = create_resize(opset=int(test_model.replace("Resize", "")))
406415
else:
407416
raise Exception("Unknown test_model in test_4d_conversion")
408417

0 commit comments

Comments
 (0)