31
31
import numpy as np
32
32
import onnx
33
33
import onnx .parser as oprs
34
+ from onnx .helper import make_opsetid
34
35
35
36
import qonnx .core .onnx_exec as oxe
36
37
from qonnx .core .datatype import DataType
@@ -328,7 +329,7 @@ def create_conv_upsample():
328
329
return model
329
330
330
331
331
- def create_resize ():
332
+ def create_resize (opset ):
332
333
"""
333
334
Creates an model for testing the 3D to 4D transform of the resize node.
334
335
"""
@@ -380,16 +381,24 @@ def create_resize():
380
381
value_info = list_of_value_infos ,
381
382
)
382
383
383
- onnx_model = qonnx_make_model (graph , producer_name = "4d_conversion_resize_test-model" )
384
+ onnx_model = qonnx_make_model (
385
+ graph , producer_name = "4d_conversion_resize_test-model" , opset_imports = [make_opsetid ("" , opset )]
386
+ )
384
387
model = ModelWrapper (onnx_model )
385
- model = model . transform ( InferShapes ())
388
+
386
389
model .set_initializer ("sizes_resize1" , np .array ([1 , 32 , 8 ], dtype = np .int64 ))
387
- model .set_initializer ("scales_resize1" , np .array ([], dtype = np .float32 ))
390
+ if opset == 11 :
391
+ model .set_initializer ("scales_resize1" , np .array ([], dtype = np .float32 ))
392
+ elif opset == 13 :
393
+ model .graph .node [0 ].input [2 ] = ""
394
+ else :
395
+ assert False , f"Undefined opset { opset } for Resize testcase creator"
388
396
model .set_initializer ("scales_resize2" , np .array ([1.0 , 1.0 , 2.0 ], dtype = np .float32 ))
397
+ model = model .transform (InferShapes ())
389
398
return model
390
399
391
400
392
- @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" , "Resize " ])
401
+ @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" , "Resize11" , "Resize13 " ])
393
402
def test_4d_conversion (test_model ):
394
403
"""
395
404
Test for the 3D to 4D transformation with a valid graph.
@@ -401,8 +410,8 @@ def test_4d_conversion(test_model):
401
410
model = create_arbitrary_model_vgg ()
402
411
elif test_model == "ConvUpsample" :
403
412
model = create_conv_upsample ()
404
- elif test_model == "Resize" :
405
- model = create_resize ()
413
+ elif "Resize" in test_model :
414
+ model = create_resize (opset = int ( test_model . replace ( "Resize" , "" )) )
406
415
else :
407
416
raise Exception ("Unknown test_model in test_4d_conversion" )
408
417
0 commit comments