Skip to content

Commit f4748e5

Browse files
committed
Make the basic renames simple aliases instead of wrappers
This removes the overhead for these functions, and also makes it so that they are still NumPy ufuncs with all the corresponding ufunc keywords and methods.
1 parent 6dfc00b commit f4748e5

File tree

3 files changed

+32
-68
lines changed

3 files changed

+32
-68
lines changed

array_api_compat/common/_aliases.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,6 @@
1414

1515
from ._helpers import _check_device, _is_numpy_array, get_namespace
1616

17-
# Basic renames
18-
def acos(x, /, xp):
19-
return xp.arccos(x)
20-
21-
def acosh(x, /, xp):
22-
return xp.arccosh(x)
23-
24-
def asin(x, /, xp):
25-
return xp.arcsin(x)
26-
27-
def asinh(x, /, xp):
28-
return xp.arcsinh(x)
29-
30-
def atan(x, /, xp):
31-
return xp.arctan(x)
32-
33-
def atan2(x1, x2, /, xp):
34-
return xp.arctan2(x1, x2)
35-
36-
def atanh(x, /, xp):
37-
return xp.arctanh(x)
38-
39-
def bitwise_left_shift(x1, x2, /, xp):
40-
return xp.left_shift(x1, x2)
41-
42-
def bitwise_invert(x, /, xp):
43-
return xp.invert(x)
44-
45-
def bitwise_right_shift(x1, x2, /, xp):
46-
return xp.right_shift(x1, x2)
47-
48-
def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray:
49-
return xp.concatenate(arrays, axis=axis)
50-
51-
def pow(x1, x2, /, xp):
52-
return xp.power(x1, x2)
53-
5417
# These functions are modified from the NumPy versions.
5518

5619
def arange(
@@ -422,10 +385,7 @@ def trunc(x: ndarray, /, xp) -> ndarray:
422385
return x
423386
return xp.trunc(x)
424387

425-
__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
426-
'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift',
427-
'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult',
428-
'UniqueInverseResult', 'unique_all', 'unique_counts',
429-
'unique_inverse', 'unique_values', 'astype', 'std', 'var',
430-
'permute_dims', 'reshape', 'argsort', 'sort', 'sum', 'prod',
431-
'ceil', 'floor', 'trunc']
388+
__all__ = ['UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
389+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
390+
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
391+
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc']

array_api_compat/cupy/_aliases.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,20 @@
1313
import cupy as cp
1414
bool = cp.bool_
1515

16-
acos = get_xp(cp)(_aliases.acos)
17-
acosh = get_xp(cp)(_aliases.acosh)
18-
asin = get_xp(cp)(_aliases.asin)
19-
asinh = get_xp(cp)(_aliases.asinh)
20-
atan = get_xp(cp)(_aliases.atan)
21-
atan2 = get_xp(cp)(_aliases.atan2)
22-
atanh = get_xp(cp)(_aliases.atanh)
23-
bitwise_left_shift = get_xp(cp)(_aliases.bitwise_left_shift)
24-
bitwise_invert = get_xp(cp)(_aliases.bitwise_invert)
25-
bitwise_right_shift = get_xp(cp)(_aliases.bitwise_right_shift)
26-
concat = get_xp(cp)(_aliases.concat)
27-
pow = get_xp(cp)(_aliases.pow)
16+
# Basic renames
17+
acos = cp.arccos
18+
acosh = cp.arccosh
19+
asin = cp.arcsin
20+
asinh = cp.arcsinh
21+
atan = cp.arctan
22+
atan2 = cp.arctan2
23+
atanh = cp.arctanh
24+
bitwise_left_shift = cp.left_shift
25+
bitwise_invert = cp.invert
26+
bitwise_right_shift = cp.right_shift
27+
concat = cp.concatenate
28+
pow = cp.power
29+
2830
arange = get_xp(cp)(_aliases.arange)
2931
empty = get_xp(cp)(_aliases.empty)
3032
empty_like = get_xp(cp)(_aliases.empty_like)

array_api_compat/numpy/_aliases.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,20 @@
1313
import numpy as np
1414
bool = np.bool_
1515

16-
acos = get_xp(np)(_aliases.acos)
17-
acosh = get_xp(np)(_aliases.acosh)
18-
asin = get_xp(np)(_aliases.asin)
19-
asinh = get_xp(np)(_aliases.asinh)
20-
atan = get_xp(np)(_aliases.atan)
21-
atan2 = get_xp(np)(_aliases.atan2)
22-
atanh = get_xp(np)(_aliases.atanh)
23-
bitwise_left_shift = get_xp(np)(_aliases.bitwise_left_shift)
24-
bitwise_invert = get_xp(np)(_aliases.bitwise_invert)
25-
bitwise_right_shift = get_xp(np)(_aliases.bitwise_right_shift)
26-
concat = get_xp(np)(_aliases.concat)
27-
pow = get_xp(np)(_aliases.pow)
16+
# Basic renames
17+
acos = np.arccos
18+
acosh = np.arccosh
19+
asin = np.arcsin
20+
asinh = np.arcsinh
21+
atan = np.arctan
22+
atan2 = np.arctan2
23+
atanh = np.arctanh
24+
bitwise_left_shift = np.left_shift
25+
bitwise_invert = np.invert
26+
bitwise_right_shift = np.right_shift
27+
concat = np.concatenate
28+
pow = np.power
29+
2830
arange = get_xp(np)(_aliases.arange)
2931
empty = get_xp(np)(_aliases.empty)
3032
empty_like = get_xp(np)(_aliases.empty_like)

0 commit comments

Comments
 (0)