Skip to content

Commit 2b76999

Browse files
authored
Merge pull request #1067 from Abdurrahheem:ash/graph_simplifier
Added split = 1 test for split layer
2 parents 1e890c8 + be05aa0 commit 2b76999

File tree

4 files changed

+11
-0
lines changed

4 files changed

+11
-0
lines changed
176 Bytes
Binary file not shown.
176 Bytes
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,15 @@ def forward(self, x):
868868
tup = torch.split(x, self.split_size_sections, self.dim)
869869
return torch.cat(tup)
870870

871+
class SimpleSplit(nn.Module):
872+
def forward(self, image):
873+
return torch.cat([img for img in image])
874+
875+
876+
model = SimpleSplit()
877+
input = torch.ones((1, 3, 2, 2))
878+
save_data_and_model("split_0", input, model, version=11)
879+
871880
model = Split()
872881
input = Variable(torch.tensor([1., 2.], dtype=torch.float32))
873882
save_data_and_model("split_1", input, model)
@@ -888,6 +897,8 @@ def forward(self, x):
888897
model = Split(dim=-1, split_size_sections=[1, 2])
889898
save_data_and_model("split_6", input2, model, version=13)
890899

900+
901+
891902
class SplitSizes(nn.Module):
892903
def __init__(self, *args, **kwargs):
893904
super(SplitSizes, self).__init__()

testdata/dnn/onnx/models/split_0.onnx

286 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)