@@ -133,6 +133,23 @@ def execute_node(self, context, graph):
133
133
pass
134
134
else :
135
135
raise Exception ("Unknown data_layout and input ndim" " combination for MultiThreshold." )
136
+
137
+ # Remember whether the shape has been modified to handle 1d or 3d data
138
+ # layouts
139
+ orig_shape = None
140
+ # If the input tensor has dimensions not covered by the NC or NCWH data
141
+ # layouts, the shape needs to be adapted such that it can be handled by
142
+ # multithreshold.
143
+ # TODO: Seems like a rather sketchy solution to support arbitrary data
144
+ # layouts. This does not even validate the assumption of channel last
145
+ # layout.
146
+ if v .ndim not in {2 , 4 }:
147
+ # Remember the original shape to be restored later
148
+ orig_shape = v .shape
149
+ # Assume last dimension to be the channel dimension C and reshape
150
+ # into NC layout which is supported by multithreshold
151
+ v = v .reshape ((- 1 , v .shape [- 1 ]))
152
+
136
153
# calculate output
137
154
output = multithreshold (v , thresholds , out_scale , out_bias )
138
155
# setting context according to output
@@ -145,6 +162,13 @@ def execute_node(self, context, graph):
145
162
pass
146
163
else :
147
164
raise Exception ("Unknown data_layout and output ndim" " combination for MultiThreshold." )
165
+
166
+ # If the shape has been modified to support arbitrary layouts, restore
167
+ # the original shape
168
+ # TODO: Part of the rather sketchy solution above.
169
+ if orig_shape is not None :
170
+ output = output .reshape (orig_shape )
171
+
148
172
context [node .output [0 ]] = output
149
173
150
174
def verify_node (self ):
0 commit comments