Skip to content

Commit 7b00753

Browse files
committed
Extend the 4d conversion test to include resize
1 parent 0cf9806 commit 7b00753

File tree

1 file changed

+62
-1
lines changed

1 file changed

+62
-1
lines changed

tests/transformation/test_4d_conversion.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,67 @@ def create_conv_upsample():
327327
model.set_initializer(tensor_name, gen_finn_dt_tensor(DataType["FLOAT32"], init_shape))
328328
return model
329329

330+
def create_resize():
331+
"""
332+
Creates an model for testing the 3D to 4D transform of the resize node.
333+
"""
334+
resize_node1 = onnx.helper.make_node(
335+
"Resize",
336+
inputs=["in_resize1", "roi_resize1", "scales_resize1", "sizes_resize1"],
337+
outputs=["out_resize1"],
338+
name="Resize1",
339+
mode="nearest",
340+
)
341+
342+
resize_node2 = onnx.helper.make_node(
343+
"Resize",
344+
inputs=["out_resize1", "roi_resize2", "scales_resize2"],
345+
outputs=["out_resize2"],
346+
name="Resize2",
347+
mode="nearest",
348+
)
349+
350+
in_resize1 = onnx.helper.make_tensor_value_info("in_resize1", onnx.TensorProto.FLOAT, [1, 32, 4])
351+
out_resize1 = onnx.helper.make_tensor_value_info("out_resize1", onnx.TensorProto.FLOAT, [1, 32, 8])
352+
out_resize2 = onnx.helper.make_tensor_value_info("out_resize2", onnx.TensorProto.FLOAT, [1, 32, 16])
353+
354+
roi_resize1 = onnx.helper.make_tensor_value_info("roi_resize1", onnx.TensorProto.FLOAT, [4])
355+
scales_resize1 = onnx.helper.make_tensor_value_info("scales_resize1", onnx.TensorProto.FLOAT, [])
356+
sizes_resize1 = onnx.helper.make_tensor_value_info("sizes_resize1", onnx.TensorProto.INT64, [3])
357+
358+
roi_resize2 = onnx.helper.make_tensor_value_info("roi_resize2", onnx.TensorProto.FLOAT, [4])
359+
scales_resize2 = onnx.helper.make_tensor_value_info("scales_resize2", onnx.TensorProto.FLOAT, [3])
360+
361+
list_of_nodes = [
362+
resize_node1,
363+
resize_node2,
364+
]
365+
list_of_value_infos = [
366+
out_resize1,
367+
roi_resize1,
368+
sizes_resize1,
369+
scales_resize1,
370+
roi_resize2,
371+
scales_resize2,
372+
]
373+
374+
graph = onnx.helper.make_graph(
375+
nodes=list_of_nodes,
376+
name="4d_conversion_resize_test_graph",
377+
inputs=[in_resize1],
378+
outputs=[out_resize2],
379+
value_info=list_of_value_infos,
380+
)
381+
382+
onnx_model = qonnx_make_model(graph, producer_name="4d_conversion_resize_test-model")
383+
model = ModelWrapper(onnx_model)
384+
model = model.transform(InferShapes())
385+
model.set_initializer("sizes_resize1", np.array([1, 32, 8], dtype=np.int64))
386+
model.set_initializer("scales_resize1", np.array([], dtype=np.float32))
387+
model.set_initializer("scales_resize2", np.array([1., 1., 2.], dtype=np.float32))
388+
return model
330389

331-
@pytest.mark.parametrize("test_model", ["Quartz", "VGG", "ConvUpsample"])
390+
@pytest.mark.parametrize("test_model", ["Quartz", "VGG", "ConvUpsample", "Resize"])
332391
def test_4d_conversion(test_model):
333392
"""
334393
Test for the 3D to 4D transformation with a valid graph.
@@ -340,6 +399,8 @@ def test_4d_conversion(test_model):
340399
model = create_arbitrary_model_vgg()
341400
elif test_model == "ConvUpsample":
342401
model = create_conv_upsample()
402+
elif test_model == "Resize":
403+
model = create_resize()
343404
else:
344405
raise Exception("Unknown test_model in test_4d_conversion")
345406

0 commit comments

Comments
 (0)