Skip to content

Commit e26f4d1

Browse files
chinhuang007tjingrant
authored andcommitted
Fix slice negative starts (#275)
Fix slice negative starts, which is documented in ONNX https://github.com/onnx/onnx/blob/master/docs/Operators.md#Slice Also add unit test for slice.
1 parent 6013e3a commit e26f4d1

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

onnx_tf/handlers/backend/slice.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def version_1(cls, node, **kwargs):
2424
axes = node.attrs.get("axes", list(range(slice_len)))
2525

2626
for i in range(slice_len):
27+
starts[i] = full_sizes[
28+
axes[i]] + starts[i] if starts[i] < 0 else starts[i]
2729
ends[i] = full_sizes[axes[i]] + ends[i] if ends[i] < 0 else ends[i]
2830
if full_sizes[axes[i]] is not None:
2931
ends[i] = np.min([full_sizes[axes[i]], ends[i]])

test/backend/test_node.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ def test_dot(self):
307307
def test_dynamic_slice(self):
308308
if defs.onnx_opset_version() < 9:
309309
raise unittest.SkipTest(
310-
"ONNX version {} doesn't support DynamicSlice."
311-
.format(defs.onnx_opset_version()))
310+
"ONNX version {} doesn't support DynamicSlice.".format(
311+
defs.onnx_opset_version()))
312312
axes = np.array([0, 1], dtype=np.long)
313313
starts = np.array([1, 0], dtype=np.long)
314314
ends = np.array([2, 3], dtype=np.long)
@@ -737,12 +737,24 @@ def test_size(self):
737737
np.testing.assert_almost_equal(output["Y"], np.size(x))
738738

739739
def test_slice(self):
740-
# TODO: API update or fix onnx version
741-
return
742-
node_def = helper.make_node("Slice", ["X", "Y", "Z", "W"], ["S"])
740+
# test case 1 with normal inputs
741+
axes = [0, 1, 2]
742+
starts = [0, 0, 0]
743+
ends = [2, 2, 2]
744+
node_def = helper.make_node(
745+
"Slice", ["X"], ["S"], axes=axes, starts=starts, ends=ends)
743746
x = self._get_rnd([1000]).reshape([10, 10, 10])
744-
output = run_node(node_def, [x, [0, 1, 2], [0, 0, 0], [2, 2, 2]])
747+
output = run_node(node_def, [x])
745748
np.testing.assert_almost_equal(output["S"], x[0:2, 0:2, 0:2])
749+
# test case 2 with negative, out-of-bound and default inputs
750+
axes = [0, 2]
751+
starts = [0, -7]
752+
ends = [-8, 20]
753+
node_def = helper.make_node(
754+
"Slice", ["X"], ["S"], axes=axes, starts=starts, ends=ends)
755+
x = self._get_rnd([1000]).reshape([10, 10, 10])
756+
output = run_node(node_def, [x])
757+
np.testing.assert_almost_equal(output["S"], x[0:-8, :, -7:20])
746758

747759
def test_softplus(self):
748760
node_def = helper.make_node("Softplus", ["X"], ["Y"])

0 commit comments

Comments
 (0)