@@ -327,8 +327,67 @@ def create_conv_upsample():
327
327
model .set_initializer (tensor_name , gen_finn_dt_tensor (DataType ["FLOAT32" ], init_shape ))
328
328
return model
329
329
330
+ def create_resize ():
331
+ """
332
+ Creates an model for testing the 3D to 4D transform of the resize node.
333
+ """
334
+ resize_node1 = onnx .helper .make_node (
335
+ "Resize" ,
336
+ inputs = ["in_resize1" , "roi_resize1" , "scales_resize1" , "sizes_resize1" ],
337
+ outputs = ["out_resize1" ],
338
+ name = "Resize1" ,
339
+ mode = "nearest" ,
340
+ )
341
+
342
+ resize_node2 = onnx .helper .make_node (
343
+ "Resize" ,
344
+ inputs = ["out_resize1" , "roi_resize2" , "scales_resize2" ],
345
+ outputs = ["out_resize2" ],
346
+ name = "Resize2" ,
347
+ mode = "nearest" ,
348
+ )
349
+
350
+ in_resize1 = onnx .helper .make_tensor_value_info ("in_resize1" , onnx .TensorProto .FLOAT , [1 , 32 , 4 ])
351
+ out_resize1 = onnx .helper .make_tensor_value_info ("out_resize1" , onnx .TensorProto .FLOAT , [1 , 32 , 8 ])
352
+ out_resize2 = onnx .helper .make_tensor_value_info ("out_resize2" , onnx .TensorProto .FLOAT , [1 , 32 , 16 ])
353
+
354
+ roi_resize1 = onnx .helper .make_tensor_value_info ("roi_resize1" , onnx .TensorProto .FLOAT , [4 ])
355
+ scales_resize1 = onnx .helper .make_tensor_value_info ("scales_resize1" , onnx .TensorProto .FLOAT , [])
356
+ sizes_resize1 = onnx .helper .make_tensor_value_info ("sizes_resize1" , onnx .TensorProto .INT64 , [3 ])
357
+
358
+ roi_resize2 = onnx .helper .make_tensor_value_info ("roi_resize2" , onnx .TensorProto .FLOAT , [4 ])
359
+ scales_resize2 = onnx .helper .make_tensor_value_info ("scales_resize2" , onnx .TensorProto .FLOAT , [3 ])
360
+
361
+ list_of_nodes = [
362
+ resize_node1 ,
363
+ resize_node2 ,
364
+ ]
365
+ list_of_value_infos = [
366
+ out_resize1 ,
367
+ roi_resize1 ,
368
+ sizes_resize1 ,
369
+ scales_resize1 ,
370
+ roi_resize2 ,
371
+ scales_resize2 ,
372
+ ]
373
+
374
+ graph = onnx .helper .make_graph (
375
+ nodes = list_of_nodes ,
376
+ name = "4d_conversion_resize_test_graph" ,
377
+ inputs = [in_resize1 ],
378
+ outputs = [out_resize2 ],
379
+ value_info = list_of_value_infos ,
380
+ )
381
+
382
+ onnx_model = qonnx_make_model (graph , producer_name = "4d_conversion_resize_test-model" )
383
+ model = ModelWrapper (onnx_model )
384
+ model = model .transform (InferShapes ())
385
+ model .set_initializer ("sizes_resize1" , np .array ([1 , 32 , 8 ], dtype = np .int64 ))
386
+ model .set_initializer ("scales_resize1" , np .array ([], dtype = np .float32 ))
387
+ model .set_initializer ("scales_resize2" , np .array ([1. , 1. , 2. ], dtype = np .float32 ))
388
+ return model
330
389
331
- @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" ])
390
+ @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" , "Resize" ])
332
391
def test_4d_conversion (test_model ):
333
392
"""
334
393
Test for the 3D to 4D transformation with a valid graph.
@@ -340,6 +399,8 @@ def test_4d_conversion(test_model):
340
399
model = create_arbitrary_model_vgg ()
341
400
elif test_model == "ConvUpsample" :
342
401
model = create_conv_upsample ()
402
+ elif test_model == "Resize" :
403
+ model = create_resize ()
343
404
else :
344
405
raise Exception ("Unknown test_model in test_4d_conversion" )
345
406
0 commit comments