Skip to content

Commit 5545635

Browse files
committed
Update torch reduction functions that don't support multiple axes
Instead of applying them multiple times, we move the dimensions to the end and flatten, and apply them once.
1 parent bb4d3af commit 5545635

File tree

1 file changed

+24
-28
lines changed

1 file changed

+24
-28
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,21 @@ def _axis_none_keepdims(x, ndim, keepdims):
201201
x = torch.unsqueeze(x, 0)
202202
return x
203203

204+
def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
205+
# Some reductions don't support multiple axes
206+
# (https://github.com/pytorch/pytorch/issues/56586).
207+
axes = _normalize_axes(axis, x.ndim)
208+
for a in reversed(axes):
209+
x = torch.movedim(x, a, -1)
210+
x = torch.flatten(x, -len(axes))
211+
212+
out = f(x, -1, **kwargs)
213+
214+
if keepdims:
215+
for a in axes:
216+
out = torch.unsqueeze(out, a)
217+
return out
218+
204219
def prod(x: array,
205220
/,
206221
*,
@@ -226,14 +241,7 @@ def prod(x: array,
226241
# torch.prod doesn't support multiple axes
227242
# (https://github.com/pytorch/pytorch/issues/56586).
228243
if isinstance(axis, tuple):
229-
axes = _normalize_axes(axis, x.ndim)
230-
for i, a in enumerate(axes):
231-
if keepdims:
232-
x = torch.prod(x, a, dtype=dtype, **kwargs)
233-
x = torch.unsqueeze(x, a)
234-
else:
235-
x = torch.prod(x, a - i, dtype=dtype, **kwargs)
236-
return x
244+
return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs)
237245
if axis is None:
238246
# torch doesn't support keepdims with axis=None
239247
# (https://github.com/pytorch/pytorch/issues/71209)
@@ -281,21 +289,15 @@ def any(x: array,
281289
axis: Optional[Union[int, Tuple[int, ...]]] = None,
282290
keepdims: bool = False,
283291
**kwargs) -> array:
284-
# torch.any doesn't support multiple axes
285-
# (https://github.com/pytorch/pytorch/issues/56586).
286292
x = torch.asarray(x)
287293
ndim = x.ndim
288294
if axis == ():
289295
return x.to(torch.bool)
296+
# torch.any doesn't support multiple axes
297+
# (https://github.com/pytorch/pytorch/issues/56586).
290298
if isinstance(axis, tuple):
291-
axes = _normalize_axes(axis, x.ndim)
292-
for i, a in enumerate(axes):
293-
if keepdims:
294-
x = torch.any(x, a, **kwargs)
295-
x = torch.unsqueeze(x, a)
296-
else:
297-
x = torch.any(x, a - i, **kwargs)
298-
return x.to(torch.bool)
299+
res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs)
300+
return res.to(torch.bool)
299301
if axis is None:
300302
# torch doesn't support keepdims with axis=None
301303
# (https://github.com/pytorch/pytorch/issues/71209)
@@ -312,21 +314,15 @@ def all(x: array,
312314
axis: Optional[Union[int, Tuple[int, ...]]] = None,
313315
keepdims: bool = False,
314316
**kwargs) -> array:
315-
# torch.all doesn't support multiple axes
316-
# (https://github.com/pytorch/pytorch/issues/56586).
317317
x = torch.asarray(x)
318318
ndim = x.ndim
319319
if axis == ():
320320
return x.to(torch.bool)
321+
# torch.all doesn't support multiple axes
322+
# (https://github.com/pytorch/pytorch/issues/56586).
321323
if isinstance(axis, tuple):
322-
axes = _normalize_axes(axis, ndim)
323-
for i, a in enumerate(axes):
324-
if keepdims:
325-
x = torch.all(x, a, **kwargs)
326-
x = torch.unsqueeze(x, a)
327-
else:
328-
x = torch.all(x, a - i, **kwargs)
329-
return x.to(torch.bool)
324+
res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs)
325+
return res.to(torch.bool)
330326
if axis is None:
331327
# torch doesn't support keepdims with axis=None
332328
# (https://github.com/pytorch/pytorch/issues/71209)

0 commit comments

Comments
 (0)