|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | from ._dtypes import _floating_dtypes, _numeric_dtypes
|
| 4 | +from ._manipulation_functions import reshape |
4 | 5 | from ._array_object import Array
|
5 | 6 |
|
| 7 | +from numpy.core.numeric import normalize_axis_tuple |
| 8 | + |
6 | 9 | from typing import TYPE_CHECKING
|
7 | 10 | if TYPE_CHECKING:
|
8 | 11 | from ._typing import Literal, Optional, Sequence, Tuple, Union
|
@@ -395,18 +398,38 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No
|
395 | 398 | if x.dtype not in _floating_dtypes:
|
396 | 399 | raise TypeError('Only floating-point dtypes are allowed in norm')
|
397 | 400 |
|
| 401 | + # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or |
| 402 | + # when axis=None and the input is 2-D, so to force a vector norm, we make |
| 403 | + # it so the input is 1-D (for axis=None), or reshape so that norm is done |
| 404 | + # on a single dimension. |
398 | 405 | a = x._array
|
399 | 406 | if axis is None:
|
400 |
| - a = a.flatten() |
401 |
| - axis = 0 |
| 407 | + # Note: np.linalg.norm() doesn't handle 0-D arrays |
| 408 | + a = a.ravel() |
| 409 | + _axis = 0 |
402 | 410 | elif isinstance(axis, tuple):
|
403 |
| - # Note: The axis argument supports any number of axes, whereas norm() |
404 |
| - # only supports a single axis for vector norm. |
405 |
| - rest = tuple(i for i in range(a.ndim) if i not in axis) |
| 411 | + # Note: The axis argument supports any number of axes, whereas |
| 412 | + # np.linalg.norm() only supports a single axis for vector norm. |
| 413 | + normalized_axis = normalize_axis_tuple(axis, x.ndim) |
| 414 | + rest = tuple(i for i in range(a.ndim) if i not in normalized_axis) |
406 | 415 | newshape = axis + rest
|
407 |
| - a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest])) |
408 |
| - axis = 0 |
409 |
| - return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord)) |
| 416 | + a = np.transpose(a, newshape).reshape( |
| 417 | + (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest])) |
| 418 | + _axis = 0 |
| 419 | + else: |
| 420 | + _axis = axis |
| 421 | + |
| 422 | + res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord)) |
| 423 | + |
| 424 | + if keepdims: |
| 425 | + # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks |
| 426 | + # above to avoid matrix norm logic. |
| 427 | + shape = list(x.shape) |
| 428 | + _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) |
| 429 | + for i in _axis: |
| 430 | + shape[i] = 1 |
| 431 | + res = reshape(res, tuple(shape)) |
410 | 432 |
|
| 433 | + return res |
411 | 434 |
|
412 | 435 | __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']
|
0 commit comments