Skip to content
This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit 96773d1

Browse files
authored
Fix for converting Tile Slice (#495)
* Fix for converting `Tile` `Slice` - Tile: pass `reps` as a int list instead of numpy array to `add_tile - Slice: If shape is not known but rank is known, proceed with end_mask false * adding pytorch model test
1 parent 2000d40 commit 96773d1

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

onnx_coreml/_operators_nd.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,8 +1890,15 @@ def _convert_slice_ir4v9(builder, node, graph, err):
18901890
'''
18911891
convert to CoreML Slice Static Layer:
18921892
https://github.com/apple/coremltools/blob/655b3be5cc0d42c3c4fa49f0f0e4a93a26b3e492/mlmodel/format/NeuralNetwork.proto#L5082
1893-
'''
1894-
data_shape = graph.shape_dict[node.inputs[0]]
1893+
'''
1894+
if node.inputs[0] in graph.shape_dict:
1895+
data_shape = graph.shape_dict[node.inputs[0]]
1896+
else:
1897+
rank = builder._get_rank(node.inputs[0])
1898+
if rank == -1:
1899+
return err.unsupported_op_configuration(builder, node, graph, "Input shape not available")
1900+
data_shape = [INT_MAX] * rank
1901+
18951902
len_of_data = len(data_shape)
18961903
begin_masks = [True] * len_of_data
18971904
end_masks = [True] * len_of_data
@@ -2160,7 +2167,7 @@ def _convert_tile(builder, node, graph, err):
21602167
name=node.name,
21612168
input_name=node.inputs[0],
21622169
output_name=node.outputs[0],
2163-
reps=node.input_tensors[node.inputs[1]]
2170+
reps=node.input_tensors[node.inputs[1]].astype(np.int32).tolist()
21642171
)
21652172

21662173
def _convert_topk(builder, node, graph, err):

tests/pytorch_model_test.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,20 @@ def forward(self, x):
572572
torch_model.train(False)
573573
_test_torch_model_single_io(torch_model, (18, 4, 5), (18, 4, 5), target_ios=target_ios) # type: ignore
574574

575+
class OperatorTests(unittest.TestCase):
576+
'''
577+
Operator test for Operator
578+
'''
579+
@unittest.skipIf(macos_version() < MIN_MACOS_VERSION_10_15,
580+
'macOS 10.15+ required. Skipping test.')
581+
def test_repeat(self, target_ios='13'):
582+
class Net(nn.Module):
583+
def forward(self, x):
584+
return x.repeat([2, 3, 1])
585+
torch_model = Net()
586+
torch_model.train(False)
587+
_test_torch_model_single_io(torch_model, (18, 4, 5), (18, 4, 5), target_ios=target_ios) # type: ignore
588+
575589
class BinaryOperationTests(unittest.TestCase):
576590
'''
577591
Binary Operation Test cases
@@ -600,7 +614,7 @@ def forward(self, x):
600614
y1 = torch.rand((4, 5))
601615
y2 = torch.rand((18, 4, 5))
602616
y3 = 7.234
603-
617+
604618
torch_model = Net() # type: ignore
605619
torch_model.train(False)
606620
_test_torch_model_single_io(torch_model, (18, 4, 5), (18, 4, 5), target_ios=target_ios) # type: ignore
@@ -653,7 +667,7 @@ def forward(self, x):
653667
y1 = torch.rand((4, 5))
654668
y2 = torch.rand((18, 4, 5))
655669
y3 = 7.234
656-
670+
657671
torch_model = Net() # type: ignore
658672
torch_model.train(False)
659673
_test_torch_model_single_io(torch_model, (18, 4, 5), (18, 4, 5), target_ios=target_ios) # type: ignore
@@ -696,7 +710,7 @@ def forward(self, x):
696710
e = torch.rand((5))
697711
f = 8.234
698712
g = 5
699-
713+
700714
torch_model = Net() # type: ignore
701715
torch_model.train(False)
702716
_test_torch_model_single_io(torch_model, (18, 4, 5), (18, 4, 5), target_ios=target_ios) # type: ignore
@@ -742,6 +756,7 @@ def test_cast_removal_transformation(self, target_ios='13'):
742756
torch_model.train(False)
743757
_test_torch_model_single_io(torch_model, (1, 18, 4, 5), (1, 18, 8, 10), target_ios=target_ios)
744758

759+
745760
if __name__ == '__main__':
746761
unittest.main()
747762
#suite = unittest.TestSuite()

0 commit comments

Comments
 (0)