@@ -307,8 +307,8 @@ def test_dot(self):
307
307
def test_dynamic_slice (self ):
308
308
if defs .onnx_opset_version () < 9 :
309
309
raise unittest .SkipTest (
310
- "ONNX version {} doesn't support DynamicSlice."
311
- . format ( defs .onnx_opset_version ()))
310
+ "ONNX version {} doesn't support DynamicSlice." . format (
311
+ defs .onnx_opset_version ()))
312
312
axes = np .array ([0 , 1 ], dtype = np .long )
313
313
starts = np .array ([1 , 0 ], dtype = np .long )
314
314
ends = np .array ([2 , 3 ], dtype = np .long )
@@ -737,12 +737,24 @@ def test_size(self):
737
737
np .testing .assert_almost_equal (output ["Y" ], np .size (x ))
738
738
739
739
def test_slice (self ):
740
- # TODO: API update or fix onnx version
741
- return
742
- node_def = helper .make_node ("Slice" , ["X" , "Y" , "Z" , "W" ], ["S" ])
740
+ # test case 1 with normal inputs
741
+ axes = [0 , 1 , 2 ]
742
+ starts = [0 , 0 , 0 ]
743
+ ends = [2 , 2 , 2 ]
744
+ node_def = helper .make_node (
745
+ "Slice" , ["X" ], ["S" ], axes = axes , starts = starts , ends = ends )
743
746
x = self ._get_rnd ([1000 ]).reshape ([10 , 10 , 10 ])
744
- output = run_node (node_def , [x , [ 0 , 1 , 2 ], [ 0 , 0 , 0 ], [ 2 , 2 , 2 ] ])
747
+ output = run_node (node_def , [x ])
745
748
np .testing .assert_almost_equal (output ["S" ], x [0 :2 , 0 :2 , 0 :2 ])
749
+ # test case 2 with negative, out-of-bound and default inputs
750
+ axes = [0 , 2 ]
751
+ starts = [0 , - 7 ]
752
+ ends = [- 8 , 20 ]
753
+ node_def = helper .make_node (
754
+ "Slice" , ["X" ], ["S" ], axes = axes , starts = starts , ends = ends )
755
+ x = self ._get_rnd ([1000 ]).reshape ([10 , 10 , 10 ])
756
+ output = run_node (node_def , [x ])
757
+ np .testing .assert_almost_equal (output ["S" ], x [0 :- 8 , :, - 7 :20 ])
746
758
747
759
def test_softplus (self ):
748
760
node_def = helper .make_node ("Softplus" , ["X" ], ["Y" ])
0 commit comments