Skip to content

Commit 81a0744

Browse files
committed
Add rudimentary support for "arbitrary" dimensions in MultiThreshold
This allows node execution of MultiThreshold operators with arbitrary number of dimensions, as long as the channel dimension is last. This is necessary to run some verification steps of attention operators which, at least for some intermediate steps, have 3 dimensional data layouts. This does not change the behavior of execution on the already existing 2d and 4d data layouts.
1 parent cadd6b2 commit 81a0744

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

src/qonnx/custom_op/general/multithreshold.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,23 @@ def execute_node(self, context, graph):
133133
pass
134134
else:
135135
raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")
136+
137+
# Remember whether the shape has been modified to handle 1d or 3d data
138+
# layouts
139+
orig_shape = None
140+
# If the input tensor has dimensions not covered by the NC or NCWH data
141+
# layouts, the shape needs to be adapted such that it can be handled by
142+
# multithreshold.
143+
# TODO: Seems like a rather sketchy solution to support arbitrary data
144+
# layouts. This does not even validate the assumption of channel last
145+
# layout.
146+
if v.ndim not in {2, 4}:
147+
# Remember the original shape to be restored later
148+
orig_shape = v.shape
149+
# Assume last dimension to be the channel dimension C and reshape
150+
# into NC layout which is supported by multithreshold
151+
v = v.reshape((-1, v.shape[-1]))
152+
136153
# calculate output
137154
output = multithreshold(v, thresholds, out_scale, out_bias)
138155
# setting context according to output
@@ -145,6 +162,13 @@ def execute_node(self, context, graph):
145162
pass
146163
else:
147164
raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")
165+
166+
# If the shape has been modified to support arbitrary layouts, restore
167+
# the original shape
168+
# TODO: Part of the rather sketchy solution above.
169+
if orig_shape is not None:
170+
output = output.reshape(orig_shape)
171+
148172
context[node.output[0]] = output
149173

150174
def verify_node(self):

0 commit comments

Comments
 (0)