From a4327e133f43f917929cef450d03cc9ed5cf620c Mon Sep 17 00:00:00 2001 From: fthielke Date: Tue, 19 Oct 2021 18:06:37 +0200 Subject: [PATCH 1/2] Fixed Conv3DTranspose with strides for data format channels_first (fixes #1714) While shape calculations for the input correctly distinguished between channels_first and channels_last, shape calculations for the inputs of the final Slice and Pad nodes always assumed channels_last format. Signed-off-by: fthielke --- tf2onnx/onnx_opset/nn.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 46630390d..81543144f 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -504,14 +504,15 @@ def version_1(cls, ctx, node, **kwargs): use_strides_workaround = False input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64}) output_shape = ctx.make_node("Shape", [node.output[0]]) + sp_index_start = 1 if is_channels_last(node) else 2 output_h = GraphBuilder(ctx).make_slice( - {"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]}) + {"data": output_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]}) output_w = GraphBuilder(ctx).make_slice( - {"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]}) + {"data": output_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]}) expect_h = GraphBuilder(ctx).make_slice( - {"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]}) + {"data": input_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]}) expect_w = GraphBuilder(ctx).make_slice( - {"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]}) + {"data": input_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]}) diff_h = ctx.make_node("Sub", [output_h, expect_h]) diff_w = ctx.make_node("Sub", [output_w, expect_w]) nonneg_diff_h = diff_h @@ -528,10 +529,12 @@ def version_1(cls, ctx, node, **kwargs): end_h = ctx.make_node("Add", [start_h.output[0], expect_h]) end_w = ctx.make_node("Add", [start_w.output[0], expect_w]) if spatial == 3: - output_d = GraphBuilder(ctx).make_slice( - {"data": output_shape.output[0], "ends": [4], "starts": [3], "axes": [0]}) - expect_d = GraphBuilder(ctx).make_slice( - {"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]}) + output_d = GraphBuilder(ctx).make_slice({ + "data": output_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0] + }) + expect_d = GraphBuilder(ctx).make_slice({ + "data": input_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0] + }) diff_d = ctx.make_node("Sub", [output_d, expect_d]) nonneg_diff_d = diff_d if use_strides_workaround: @@ -543,12 +546,12 @@ def version_1(cls, ctx, node, **kwargs): attr={"axis": 0}) ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0], end_d.output[0]], attr={"axis": 0}) slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"), - np.array([1, 2, 3], dtype=np.int64)) + np.arange(sp_index_start, sp_index_start + 3, dtype=np.int64)) else: starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0}) ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0}) slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"), - np.array([1, 2], dtype=np.int64)) + np.arange(sp_index_start, sp_index_start + 2, dtype=np.int64)) slice_node = ctx.make_node("Slice", [node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]], @@ -571,10 +574,16 @@ def version_1(cls, ctx, node, **kwargs): neg_diff_d = ctx.make_node("Neg", [diff_d.output[0]]) shrink_d_by = ctx.make_node("Max", [neg_diff_d.output[0], const_zero.output[0]]) sdb = shrink_d_by.output[0] - pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0}) + if is_channels_last(node): + pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0}) + else: + pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, cz, shb, swb, sdb], attr={"axis": 0}) padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]]) else: - pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0}) + if is_channels_last(node): + pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0}) + else: + pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb], attr={"axis": 0}) padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]]) final_node = padded_node From abf3695ba336f63d9cc16ecef41b19ec980bfc89 Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Wed, 12 Oct 2022 14:47:39 +0800 Subject: [PATCH 2/2] Fix pylint issue. Signed-off-by: Jay Zhang --- tf2onnx/onnx_opset/nn.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 81543144f..d63fb7931 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -506,13 +506,17 @@ def version_1(cls, ctx, node, **kwargs): output_shape = ctx.make_node("Shape", [node.output[0]]) sp_index_start = 1 if is_channels_last(node) else 2 output_h = GraphBuilder(ctx).make_slice( - {"data": output_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]}) + {"data": output_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], + "axes": [0]}) output_w = GraphBuilder(ctx).make_slice( - {"data": output_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]}) + {"data": output_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], + "axes": [0]}) expect_h = GraphBuilder(ctx).make_slice( - {"data": input_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]}) + {"data": input_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], + "axes": [0]}) expect_w = GraphBuilder(ctx).make_slice( - {"data": input_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]}) + {"data": input_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], + "axes": [0]}) diff_h = ctx.make_node("Sub", [output_h, expect_h]) diff_w = ctx.make_node("Sub", [output_w, expect_w]) nonneg_diff_h = diff_h @@ -530,11 +534,11 @@ def version_1(cls, ctx, node, **kwargs): end_w = ctx.make_node("Add", [start_w.output[0], expect_w]) if spatial == 3: output_d = GraphBuilder(ctx).make_slice({ - "data": output_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0] - }) + "data": output_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], + "axes": [0]}) expect_d = GraphBuilder(ctx).make_slice({ - "data": input_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0] - }) + "data": input_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], + "axes": [0]}) diff_d = ctx.make_node("Sub", [output_d, expect_d]) nonneg_diff_d = diff_d if use_strides_workaround: