@@ -237,9 +237,21 @@ def transform(self, model, node):
237
237
# zero bias, propagate through, if possible
238
238
# (always possible if scale is scalar)
239
239
try :
240
- newscale = np .broadcast_to (scale , output .shape ) # make sure broadcastable
240
+ if scale .ndim > 1 :
241
+ # undo any broadcast_to
242
+ reduced_scale = _remove_redundant_dims (scale )
243
+ if reduced_scale .shape [- 1 ] == 1 :
244
+ reduced_scale = reduced_scale [..., 0 ]
245
+ if node .attributes ['n_dim' ] == 1 :
246
+ scale_trans = np .transpose (reduced_scale , (1 , 0 ))
247
+ else :
248
+ scale_trans = np .transpose (reduced_scale , (1 , 2 , 0 ))
249
+ newscale = np .broadcast_to (scale_trans , output .shape ) # make sure broadcastable
250
+ can_propagate = True
251
+ else :
252
+ newscale = np .broadcast_to (scale , output .shape ) # make sure broadcastable
253
+ can_propagate = True
241
254
newbias = np .zeros (output .shape )
242
- can_propagate = True
243
255
except ValueError :
244
256
can_propagate = False
245
257
@@ -309,3 +321,14 @@ def transform(self, model, node):
309
321
new_node = model .make_node ('ApplyAlpha' , apply_alpha .name , new_attrs , [x for x in node .outputs ])
310
322
model .insert_node (new_node )
311
323
return True
324
+
325
+
326
def _remove_redundant_dims(X):
    """Collapse each axis of ``X`` whose values are all identical to size 1.

    This is roughly the inverse of ``np.broadcast_to``: for every axis, the
    slice at index 0 is compared against the whole array, and if they match
    everywhere the array is replaced by that single slice (the axis is kept,
    with length 1). The number of dimensions never changes.

    Comparison uses ``==``, so an axis containing NaN is never collapsed.

    Args:
        X: NumPy array to reduce.

    Returns:
        An array with the same ``ndim`` as ``X`` in which every axis that only
        repeated its first slice has length 1. If nothing can be collapsed,
        ``X`` itself is returned.
    """
    for axis in range(X.ndim):
        # A size-1 axis is already reduced, and a size-0 axis has no element 0
        # to take (np.take would raise IndexError), so leave both untouched.
        if X.shape[axis] <= 1:
            continue
        # Indexing with a list keeps the axis (length 1), so the comparison
        # below broadcasts the first slice against the full array.
        first_slice = np.take(X, [0], axis=axis)
        if np.all(first_slice == X):
            X = first_slice
    return X
0 commit comments