@@ -201,6 +201,21 @@ def _axis_none_keepdims(x, ndim, keepdims):
201
201
x = torch .unsqueeze (x , 0 )
202
202
return x
203
203
204
def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
    """Apply the single-axis reduction ``f`` over every axis in ``axis``.

    Some torch reductions don't support multiple axes
    (https://github.com/pytorch/pytorch/issues/56586), so this moves all
    requested axes to the end of the tensor, flattens them into one axis,
    and performs a single ``f(x, -1, **kwargs)`` reduction.

    Parameters: f is a torch reduction taking ``(tensor, dim, **kwargs)``;
    x is the input tensor; axis is the tuple of axes to reduce over;
    keepdims reinserts size-1 dims at the reduced positions.
    Returns the reduced tensor.

    NOTE(review): assumes _normalize_axes returns non-negative axes sorted
    ascending — TODO confirm. Ascending order is what makes both the
    reversed movedim loop and the keepdims unsqueeze loop below correct.
    NOTE(review): an empty ``axis`` tuple would flatten and reduce the
    whole tensor here (``flatten(x, -0)`` == ``flatten(x, 0)``); callers
    appear to guard ``axis == ()`` before calling — verify.
    """
    # Some reductions don't support multiple axes
    # (https://github.com/pytorch/pytorch/issues/56586).
    axes = _normalize_axes(axis, x.ndim)
    # Move each reduced axis to the last position. Iterating in
    # descending order means every move only shifts dims *after* the
    # moved one, so the remaining (smaller) axis indices stay valid.
    for a in reversed(axes):
        x = torch.movedim(x, a, -1)
    # Collapse the len(axes) trailing dims into one axis to reduce over.
    x = torch.flatten(x, -len(axes))

    out = f(x, -1, **kwargs)

    if keepdims:
        # Reinsert size-1 dims at the original (ascending) positions.
        for a in axes:
            out = torch.unsqueeze(out, a)
    return out
218
+
204
219
def prod (x : array ,
205
220
/ ,
206
221
* ,
@@ -226,14 +241,7 @@ def prod(x: array,
226
241
# torch.prod doesn't support multiple axes
227
242
# (https://github.com/pytorch/pytorch/issues/56586).
228
243
if isinstance (axis , tuple ):
229
- axes = _normalize_axes (axis , x .ndim )
230
- for i , a in enumerate (axes ):
231
- if keepdims :
232
- x = torch .prod (x , a , dtype = dtype , ** kwargs )
233
- x = torch .unsqueeze (x , a )
234
- else :
235
- x = torch .prod (x , a - i , dtype = dtype , ** kwargs )
236
- return x
244
+ return _reduce_multiple_axes (torch .prod , x , axis , keepdims = keepdims , dtype = dtype , ** kwargs )
237
245
if axis is None :
238
246
# torch doesn't support keepdims with axis=None
239
247
# (https://github.com/pytorch/pytorch/issues/71209)
@@ -281,21 +289,15 @@ def any(x: array,
281
289
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
282
290
keepdims : bool = False ,
283
291
** kwargs ) -> array :
284
- # torch.any doesn't support multiple axes
285
- # (https://github.com/pytorch/pytorch/issues/56586).
286
292
x = torch .asarray (x )
287
293
ndim = x .ndim
288
294
if axis == ():
289
295
return x .to (torch .bool )
296
+ # torch.any doesn't support multiple axes
297
+ # (https://github.com/pytorch/pytorch/issues/56586).
290
298
if isinstance (axis , tuple ):
291
- axes = _normalize_axes (axis , x .ndim )
292
- for i , a in enumerate (axes ):
293
- if keepdims :
294
- x = torch .any (x , a , ** kwargs )
295
- x = torch .unsqueeze (x , a )
296
- else :
297
- x = torch .any (x , a - i , ** kwargs )
298
- return x .to (torch .bool )
299
+ res = _reduce_multiple_axes (torch .any , x , axis , keepdims = keepdims , ** kwargs )
300
+ return res .to (torch .bool )
299
301
if axis is None :
300
302
# torch doesn't support keepdims with axis=None
301
303
# (https://github.com/pytorch/pytorch/issues/71209)
@@ -312,21 +314,15 @@ def all(x: array,
312
314
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
313
315
keepdims : bool = False ,
314
316
** kwargs ) -> array :
315
- # torch.all doesn't support multiple axes
316
- # (https://github.com/pytorch/pytorch/issues/56586).
317
317
x = torch .asarray (x )
318
318
ndim = x .ndim
319
319
if axis == ():
320
320
return x .to (torch .bool )
321
+ # torch.all doesn't support multiple axes
322
+ # (https://github.com/pytorch/pytorch/issues/56586).
321
323
if isinstance (axis , tuple ):
322
- axes = _normalize_axes (axis , ndim )
323
- for i , a in enumerate (axes ):
324
- if keepdims :
325
- x = torch .all (x , a , ** kwargs )
326
- x = torch .unsqueeze (x , a )
327
- else :
328
- x = torch .all (x , a - i , ** kwargs )
329
- return x .to (torch .bool )
324
+ res = _reduce_multiple_axes (torch .all , x , axis , keepdims = keepdims , ** kwargs )
325
+ return res .to (torch .bool )
330
326
if axis is None :
331
327
# torch doesn't support keepdims with axis=None
332
328
# (https://github.com/pytorch/pytorch/issues/71209)
0 commit comments