Skip to content

Commit 4c0b4d4

Browse files
committed
Add std and var, and define __all__ in _alaises.py
1 parent 9811f3c commit 4c0b4d4

File tree

1 file changed

+52
-25
lines changed

1 file changed

+52
-25
lines changed

numpy_array_api_compat/_aliases.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,30 @@
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
from typing import Tuple
9+
from typing import Optional, Tuple, Union
1010
from numpy import ndarray, dtype
1111

1212
from typing import NamedTuple
1313

14-
from numpy import (arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh,
15-
left_shift, invert, right_shift, bool_, concatenate, power,
16-
transpose, unique)
14+
import numpy as np
1715

1816
# Basic renames
19-
acos = arccos
20-
acosh = arccosh
21-
asin = arcsin
22-
asinh = arcsinh
23-
atan = arctan
24-
atan2 = arctan2
25-
atanh = arctanh
26-
bitwise_left_shift = left_shift
27-
bitwise_invert = invert
28-
bitwise_right_shift = right_shift
29-
bool = bool_
30-
concat = concatenate
31-
pow = power
17+
acos = np.arccos
18+
acosh = np.arccosh
19+
asin = np.arcsin
20+
asinh = np.arcsinh
21+
atan = np.arctan
22+
atan2 = np.arctan2
23+
atanh = np.arctanh
24+
bitwise_left_shift = np.left_shift
25+
bitwise_invert = np.invert
26+
bitwise_right_shift = np.right_shift
27+
bool = np.bool_
28+
concat = np.concatenate
29+
pow = np.power
3230

3331
# These functions are modified from the NumPy versions.
3432

35-
# Unlike transpose(), the axes argument to permute_dims() is required.
36-
def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray:
37-
return transpose(x, axes)
38-
3933
# np.unique() is split into four functions in the array API:
4034
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
4135
# to remove polymorphic return types).
@@ -60,7 +54,7 @@ class UniqueInverseResult(NamedTuple):
6054

6155

6256
def unique_all(x: ndarray, /) -> UniqueAllResult:
63-
values, indices, inverse_indices, counts = unique(
57+
values, indices, inverse_indices, counts = np.unique(
6458
x,
6559
return_counts=True,
6660
return_index=True,
@@ -79,7 +73,7 @@ def unique_all(x: ndarray, /) -> UniqueAllResult:
7973

8074

8175
def unique_counts(x: ndarray, /) -> UniqueCountsResult:
82-
res = unique(
76+
res = np.unique(
8377
x,
8478
return_counts=True,
8579
return_index=False,
@@ -91,7 +85,7 @@ def unique_counts(x: ndarray, /) -> UniqueCountsResult:
9185

9286

9387
def unique_inverse(x: ndarray, /) -> UniqueInverseResult:
94-
values, inverse_indices = unique(
88+
values, inverse_indices = np.unique(
9589
x,
9690
return_counts=False,
9791
return_index=False,
@@ -105,7 +99,7 @@ def unique_inverse(x: ndarray, /) -> UniqueInverseResult:
10599

106100

107101
def unique_values(x: ndarray, /) -> ndarray:
108-
return unique(
102+
return np.unique(
109103
x,
110104
return_counts=False,
111105
return_index=False,
@@ -118,5 +112,38 @@ def astype(x: ndarray, dtype: dtype, /, *, copy: bool = True) -> ndarray:
118112
return x
119113
return x.astype(dtype=dtype, copy=copy)
120114

115+
# These functions have different keyword argument names
116+
117+
def std(
118+
x: ndarray,
119+
/,
120+
*,
121+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
122+
correction: Union[int, float] = 0.0, # correction instead of ddof
123+
keepdims: bool = False,
124+
) -> ndarray:
125+
return np.std(x, axis=axis, ddof=correction, keepdims=keepdims)
126+
127+
def var(
128+
x: ndarray,
129+
/,
130+
*,
131+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
132+
correction: Union[int, float] = 0.0, # correction instead of ddof
133+
keepdims: bool = False,
134+
) -> ndarray:
135+
return np.var(x, axis=axis, ddof=correction, keepdims=keepdims)
136+
137+
# Unlike transpose(), the axes argument to permute_dims() is required.
138+
def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray:
139+
return np.transpose(x, axes)
140+
121141
# from numpy import * doesn't overwrite these builtin names
122142
from numpy import abs, max, min, round
143+
144+
__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
145+
'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift',
146+
'bool', 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult',
147+
'UniqueInverseResult', 'unique_all', 'unique_counts',
148+
'unique_inverse', 'unique_values', 'astype', 'abs', 'max', 'min',
149+
'round', 'std', 'var', 'permute_dims']

0 commit comments

Comments
 (0)