Skip to content

Commit c2b258d

Browse files
einsum: conditionally do squeeze before transpose (#4079)
1 parent fcad24a commit c2b258d

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

src/onnx/parse_einsum.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*
22
* The MIT License (MIT)
33
*
4-
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
55
*
66
* Permission is hereby granted, free of charge, to any person obtaining a copy
77
* of this software and associated documentation files (the "Software"), to deal
@@ -619,7 +619,8 @@ struct parse_einsum : op_parser<parse_einsum>
619619
perm.push_back(row_output[i]);
620620
}
621621

622-
op = info.add_instruction(make_op("squeeze", {{"axes", sq_axes}}), op);
622+
if(not sq_axes.empty())
623+
op = info.add_instruction(make_op("squeeze", {{"axes", sq_axes}}), op);
623624

624625
if(not perm.empty())
625626
{
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+

2+
einsum_permute_sd3_test:�
3+
-
4+
xy"Einsum*
5+
equation"nhwpqc->nchpwq�einsum_permute_sd3_testZ#
6+
x
7+

8+

9+
@
10+
@
11+

12+

13+
b#
14+
y
15+

16+

17+

18+
@
19+

20+
@
21+
B

test/onnx/gen_onnx.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2614,6 +2614,20 @@ def einsum_permute_test():
26142614
return ([node], [x], [y])
26152615

26162616

2617+
def einsum_permute_sd3_test():
2618+
x = helper.make_tensor_value_info('x', TensorProto.FLOAT,
2619+
[1, 64, 64, 2, 2, 16])
2620+
y = helper.make_tensor_value_info('y', TensorProto.FLOAT,
2621+
[1, 16, 64, 2, 64, 2])
2622+
2623+
node = onnx.helper.make_node('Einsum',
2624+
inputs=['x'],
2625+
outputs=['y'],
2626+
equation='nhwpqc->nchpwq')
2627+
2628+
return ([node], [x], [y])
2629+
2630+
26172631
@onnx_test()
26182632
def einsum_summation_test():
26192633
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3])

0 commit comments

Comments
 (0)