Skip to content

Commit cd16b74

Browse files
authored
MAINT: Fix computation of numpy.array_api.linalg.vector_norm (#21084)
* Fix computation of numpy.array_api.linalg.vector_norm Various pieces were incorrect due to a lack of complete coverage of this function in the array API test suite. * Fix the output dtype nonstandard vector norm() Previously it would always give float64 because an internal calculation involved a NumPy scalar and a Python float. The fix is to use a 0-D array instead of a NumPy scalar so that it type promotes with the float correctly. Fixes #21083 I don't have a test for this yet because I'm unclear how exactly to test it. * Clean up the numpy.array_api.linalg.vector_norm code a little bit Original NumPy Commit: 70026c4dde47d89d6a7a4916bfac045e714a5b4f
1 parent e317f73 commit cd16b74

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

array_api_strict/linalg.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

33
from ._dtypes import _floating_dtypes, _numeric_dtypes
4+
from ._manipulation_functions import reshape
45
from ._array_object import Array
56

7+
from numpy.core.numeric import normalize_axis_tuple
8+
69
from typing import TYPE_CHECKING
710
if TYPE_CHECKING:
811
from ._typing import Literal, Optional, Sequence, Tuple, Union
@@ -395,18 +398,38 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No
395398
if x.dtype not in _floating_dtypes:
396399
raise TypeError('Only floating-point dtypes are allowed in norm')
397400

401+
# np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
402+
# when axis=None and the input is 2-D, so to force a vector norm, we make
403+
# it so the input is 1-D (for axis=None), or reshape so that norm is done
404+
# on a single dimension.
398405
a = x._array
399406
if axis is None:
400-
a = a.flatten()
401-
axis = 0
407+
# Note: np.linalg.norm() doesn't handle 0-D arrays
408+
a = a.ravel()
409+
_axis = 0
402410
elif isinstance(axis, tuple):
403-
# Note: The axis argument supports any number of axes, whereas norm()
404-
# only supports a single axis for vector norm.
405-
rest = tuple(i for i in range(a.ndim) if i not in axis)
411+
# Note: The axis argument supports any number of axes, whereas
412+
# np.linalg.norm() only supports a single axis for vector norm.
413+
normalized_axis = normalize_axis_tuple(axis, x.ndim)
414+
rest = tuple(i for i in range(a.ndim) if i not in normalized_axis)
406415
newshape = axis + rest
407-
a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest]))
408-
axis = 0
409-
return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord))
416+
a = np.transpose(a, newshape).reshape(
417+
(np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest]))
418+
_axis = 0
419+
else:
420+
_axis = axis
421+
422+
res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord))
423+
424+
if keepdims:
425+
# We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
426+
# above to avoid matrix norm logic.
427+
shape = list(x.shape)
428+
_axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
429+
for i in _axis:
430+
shape[i] = 1
431+
res = reshape(res, tuple(shape))
410432

433+
return res
411434

412435
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

0 commit comments

Comments
 (0)