Skip to content

Commit 24c0ea3

Browse files
committed
Use a better function name and use unsqueeze instead of None indexing
1 parent d38bad5 commit 24c0ea3

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,13 @@ def _normalize_axes(axis, ndim):
188188
axes.append(a)
189189
return sorted(axes)
190190

191-
def _apply_keepdims(x, ndim, keepdims):
191+
def _axis_none_keepdims(x, ndim, keepdims):
192+
# Apply keepdims when axis=None
193+
# (https://github.com/pytorch/pytorch/issues/71209)
194+
# Note that this is only valid for the axis=None case.
192195
if keepdims:
193-
return x[(None,)*ndim]
196+
for i in range(ndim):
197+
x = torch.unsqueeze(x, 0)
194198
return x
195199

196200
def prod(x: array,
@@ -230,7 +234,7 @@ def prod(x: array,
230234
# torch doesn't support keepdims with axis=None
231235
# (https://github.com/pytorch/pytorch/issues/71209)
232236
res = torch.prod(x, dtype=dtype, **kwargs)
233-
res = _apply_keepdims(res, ndim, keepdims)
237+
res = _axis_none_keepdims(res, ndim, keepdims)
234238
return res
235239

236240
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -262,7 +266,7 @@ def sum(x: array,
262266
# torch doesn't support keepdims with axis=None
263267
# (https://github.com/pytorch/pytorch/issues/71209)
264268
res = torch.sum(x, dtype=dtype, **kwargs)
265-
res = _apply_keepdims(res, ndim, keepdims)
269+
res = _axis_none_keepdims(res, ndim, keepdims)
266270
return res
267271

268272
return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -292,7 +296,7 @@ def any(x: array,
292296
# torch doesn't support keepdims with axis=None
293297
# (https://github.com/pytorch/pytorch/issues/71209)
294298
res = torch.any(x, **kwargs)
295-
res = _apply_keepdims(res, ndim, keepdims)
299+
res = _axis_none_keepdims(res, ndim, keepdims)
296300
return res.to(torch.bool)
297301

298302
# torch.any doesn't return bool for uint8
@@ -323,7 +327,7 @@ def all(x: array,
323327
# torch doesn't support keepdims with axis=None
324328
# (https://github.com/pytorch/pytorch/issues/71209)
325329
res = torch.all(x, **kwargs)
326-
res = _apply_keepdims(res, ndim, keepdims)
330+
res = _axis_none_keepdims(res, ndim, keepdims)
327331
return res.to(torch.bool)
328332

329333
# torch.all doesn't return bool for uint8
@@ -342,7 +346,7 @@ def mean(x: array,
342346
# torch doesn't support keepdims with axis=None
343347
# (https://github.com/pytorch/pytorch/issues/71209)
344348
res = torch.mean(x, **kwargs)
345-
res = _apply_keepdims(res, x.ndim, keepdims)
349+
res = _axis_none_keepdims(res, x.ndim, keepdims)
346350
return res
347351
return torch.mean(x, axis, keepdims=keepdims, **kwargs)
348352

@@ -369,7 +373,7 @@ def std(x: array,
369373
# torch doesn't support keepdims with axis=None
370374
# (https://github.com/pytorch/pytorch/issues/71209)
371375
res = torch.std(x, tuple(range(x.ndim)), correction=correction, **kwargs)
372-
res = _apply_keepdims(res, x.ndim, keepdims)
376+
res = _axis_none_keepdims(res, x.ndim, keepdims)
373377
return res
374378
return torch.std(x, axis, correction=correction, keepdims=keepdims, **kwargs)
375379

@@ -396,7 +400,7 @@ def var(x: array,
396400
# torch doesn't support keepdims with axis=None
397401
# (https://github.com/pytorch/pytorch/issues/71209)
398402
res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs)
399-
res = _apply_keepdims(res, x.ndim, keepdims)
403+
res = _axis_none_keepdims(res, x.ndim, keepdims)
400404
return res
401405
return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs)
402406

0 commit comments

Comments
 (0)