@@ -188,9 +188,13 @@ def _normalize_axes(axis, ndim):
         axes.append(a)
     return sorted(axes)
 
-def _apply_keepdims(x, ndim, keepdims):
+def _axis_none_keepdims(x, ndim, keepdims):
+    # Apply keepdims when axis=None
+    # (https://github.com/pytorch/pytorch/issues/71209)
+    # Note that this is only valid for the axis=None case.
     if keepdims:
-        return x[(None,)*ndim]
+        for i in range(ndim):
+            x = torch.unsqueeze(x, 0)
     return x
 
 def prod(x: array,
@@ -230,7 +234,7 @@ def prod(x: array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _apply_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, ndim, keepdims)
         return res
 
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -262,7 +266,7 @@ def sum(x: array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _apply_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, ndim, keepdims)
         return res
 
     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -292,7 +296,7 @@ def any(x: array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.any(x, **kwargs)
-        res = _apply_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.any doesn't return bool for uint8
@@ -323,7 +327,7 @@ def all(x: array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.all(x, **kwargs)
-        res = _apply_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.all doesn't return bool for uint8
@@ -342,7 +346,7 @@ def mean(x: array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.mean(x, **kwargs)
-        res = _apply_keepdims(res, x.ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
     return torch.mean(x, axis, keepdims=keepdims, **kwargs)
 
@@ -369,7 +373,7 @@ def std(x: array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.std(x, tuple(range(x.ndim)), correction=correction, **kwargs)
-        res = _apply_keepdims(res, x.ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
     return torch.std(x, axis, correction=correction, keepdims=keepdims, **kwargs)
 
@@ -396,7 +400,7 @@ def var(x: array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs)
-        res = _apply_keepdims(res, x.ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
     return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs)
 
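A minimal sketch (not part of the patch) of the behaviour the renamed helper emulates, assuming a 2-D input: with axis=None, torch reductions return a 0-d tensor and do not honour keepdims (pytorch/pytorch#71209), so one length-1 dimension is unsqueezed back per original axis.

import torch

x = torch.ones((3, 4))
res = torch.sum(x)               # full reduction: 0-d tensor, keepdims not applied
for _ in range(x.ndim):          # what _axis_none_keepdims does when keepdims=True
    res = torch.unsqueeze(res, 0)
print(res.shape)                 # torch.Size([1, 1]), the keepdims=True shape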