Skip to content

Commit 0a3eafd

Browse files
committed
Add several functions which are only slightly modified from numpy
1 parent 5e9b1b0 commit 0a3eafd

File tree

2 files changed

+101
-2
lines changed

2 files changed

+101
-2
lines changed

numpy_array_api_compat/_aliases.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,20 @@
22
These are functions that are just aliases of existing functions in NumPy.
33
"""
44

5+
from __future__ import annotations
6+
7+
from typing import TYPE_CHECKING
8+
if TYPE_CHECKING:
9+
from typing import Tuple
10+
from numpy import ndarray, dtype
11+
12+
from typing import NamedTuple
13+
514
from numpy import (arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh,
6-
left_shift, invert, right_shift, bool_, concatenate, power)
15+
left_shift, invert, right_shift, bool_, concatenate, power,
16+
transpose, unique)
717

18+
# Basic renames
819
acos = arccos
920
acosh = arccosh
1021
asin = arcsin
@@ -18,3 +29,91 @@
1829
bool = bool_
1930
concat = concatenate
2031
pow = power
32+
33+
# These functions are modified from the NumPy versions.
34+
35+
# Unlike transpose(), the axes argument to permute_dims() is required.
36+
def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray:
37+
return transpose(x, axes)
38+
39+
# np.unique() is split into four functions in the array API:
40+
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
41+
# to remove polymorphic return types).
42+
43+
# The functions here return namedtuples (np.unique() returns a normal
44+
# tuple).
45+
class UniqueAllResult(NamedTuple):
46+
values: ndarray
47+
indices: ndarray
48+
inverse_indices: ndarray
49+
counts: ndarray
50+
51+
52+
class UniqueCountsResult(NamedTuple):
53+
values: ndarray
54+
counts: ndarray
55+
56+
57+
class UniqueInverseResult(NamedTuple):
58+
values: ndarray
59+
inverse_indices: ndarray
60+
61+
62+
def unique_all(x: ndarray, /) -> UniqueAllResult:
63+
values, indices, inverse_indices, counts = unique(
64+
x,
65+
return_counts=True,
66+
return_index=True,
67+
return_inverse=True,
68+
equal_nan=False,
69+
)
70+
# np.unique() flattens inverse indices, but they need to share x's shape
71+
# See https://github.com/numpy/numpy/issues/20638
72+
inverse_indices = inverse_indices.reshape(x.shape)
73+
return UniqueAllResult(
74+
values,
75+
indices,
76+
inverse_indices,
77+
counts,
78+
)
79+
80+
81+
def unique_counts(x: ndarray, /) -> UniqueCountsResult:
82+
res = unique(
83+
x,
84+
return_counts=True,
85+
return_index=False,
86+
return_inverse=False,
87+
equal_nan=False,
88+
)
89+
90+
return UniqueCountsResult(*res)
91+
92+
93+
def unique_inverse(x: ndarray, /) -> UniqueInverseResult:
94+
values, inverse_indices = unique(
95+
x,
96+
return_counts=False,
97+
return_index=False,
98+
return_inverse=True,
99+
equal_nan=False,
100+
)
101+
# np.unique() flattens inverse indices, but they need to share x's shape
102+
# See https://github.com/numpy/numpy/issues/20638
103+
inverse_indices = inverse_indices.reshape(x.shape)
104+
return UniqueInverseResult(values, inverse_indices)
105+
106+
107+
def unique_values(x: ndarray, /) -> ndarray:
108+
return unique(
109+
x,
110+
return_counts=False,
111+
return_index=False,
112+
return_inverse=False,
113+
equal_nan=False,
114+
)
115+
116+
def astype(x: ndarray, dtype: dtype, /, *, copy: bool = True) -> ndarray:
117+
if not copy and dtype == x.dtype:
118+
return x
119+
return x.astype(dtype=dtype, copy=copy)

numpy_array_api_compat/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING
44
if TYPE_CHECKING:
5-
from ._typing import Literal, Optional, Tuple, Union
5+
from typing import Literal, Optional, Tuple, Union
66
from numpy import ndarray
77

88
import numpy as np

0 commit comments

Comments
 (0)