Skip to content

Commit 4e90ddf

Browse files
authored
Merge pull request #1110 from alexlyulkov:al/fixed-cumsum-inplace-flag
Added test for cumsum exclusive inplace
2 parents f509aed + a3ffc35 commit 4e90ddf

File tree

4 files changed

+19
-0
lines changed

4 files changed

+19
-0
lines changed
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2519,6 +2519,25 @@ def forward(self, x):
25192519
x = torch.randn(2, 3, 4)
25202520
save_data_and_model("cumsum_3d_dim_2", x, CumSum(dim=2), version=11)
25212521

2522+
# test: CumSum exclusive layer should not be executed inplace
2523+
dims = h.make_node("Constant", inputs=[], outputs=["dims1"], name="node-c1",
2524+
value=h.make_tensor(name="c1v", data_type=onnx.TensorProto.INT64, dims=[], vals=np.asarray([1, ], dtype=np.int64)))
2525+
one = h.make_node("Constant", inputs=[], outputs=["one1"], name="node-c2",
2526+
value=h.make_tensor(name="c2v", data_type=onnx.TensorProto.FLOAT, dims=[], vals=np.asarray([1, ], dtype=np.float32)))
2527+
2528+
mult = h.make_node("Mul", inputs=["input1", "one1"], outputs=["mul_output1"], name="node-m1")
2529+
cumsum = h.make_node("CumSum", inputs=["mul_output1", "dims1"], outputs=["cumsum_output1"], name="node-r1", exclusive=1)
2530+
2531+
graph = h.make_graph([dims, one, mult, cumsum], "graph123",
2532+
[h.make_tensor_value_info("input1", onnx.TensorProto.FLOAT, [1, 3, 1, 1]),],
2533+
[h.make_tensor_value_info("cumsum_output1", onnx.TensorProto.FLOAT, [1, 3, 1, 1])])
2534+
cumsum_model = h.make_model(graph, producer_name="model_cumsum")
2535+
onnx.checker.check_model(cumsum_model)
2536+
2537+
input_np = np.array([1, 2, 3], dtype=np.float32).reshape(1, 3, 1, 1)
2538+
output_np = np.array([0, 1, 3], dtype=np.float32).reshape(1, 3, 1, 1)
2539+
save_data_and_onnx_model("cumsum_exclusive_inplace", input_np, output_np, cumsum_model)
2540+
25222541
# where layer
25232542
class Where(nn.Module):
25242543
def __init__(self, *args, **kwargs):
Binary file not shown.

0 commit comments

Comments
 (0)