Skip to content

Commit 689a776

Browse files
committed
Respect the extension flag in linalg and fft
This behavior still needs to be tested. This required moving the linalg functions that are also in the main namespace so that they can still work there even when the linalg extension is disabled. The way I've decided to implement this is that the functions will not raise an exception until they are called. It would probably be more convenient for users if they raised an attribute error, or if the extension namespace itself did, like it would in a real library without the given extension. But the implementation for this would be a lot more complicated and didn't really feel worth it to me.
1 parent 4705b9f commit 689a776

File tree

5 files changed

+149
-54
lines changed

5 files changed

+149
-54
lines changed

array_api_strict/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@
244244

245245
__all__ += ["linalg"]
246246

247-
from .linalg import matmul, tensordot, matrix_transpose, vecdot
247+
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
248248

249249
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
250250

array_api_strict/_flags.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,17 @@ def wrapper(*args, **kwargs):
273273
raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
274274
return func(*args, **kwargs)
275275
return wrapper
276+
277+
def requires_extension(extension):
278+
def decorator(func):
279+
@functools.wraps(func)
280+
def wrapper(*args, **kwargs):
281+
if extension not in ENABLED_EXTENSIONS:
282+
if extension == 'linalg' \
283+
and func.__name__ in ['matmul', 'tensordot',
284+
'matrix_transpose', 'vecdot']:
285+
raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.")
286+
raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict")
287+
return func(*args, **kwargs)
288+
return wrapper
289+
return decorator
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""
2+
These functions are all also defined in the linalg extension, but we include
3+
them here with wrappers in linalg so that the wrappers can be disabled if the
4+
linalg extension is disabled in the flags.
5+
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from ._dtypes import _numeric_dtypes
11+
12+
from ._array_object import Array
13+
14+
from typing import TYPE_CHECKING
15+
if TYPE_CHECKING:
16+
from ._typing import Sequence, Tuple, Union
17+
18+
import numpy.linalg
19+
import numpy as np
20+
21+
# Note: matmul is the numpy top-level namespace but not in np.linalg
22+
def matmul(x1: Array, x2: Array, /) -> Array:
23+
"""
24+
Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
25+
26+
See its docstring for more information.
27+
"""
28+
# Note: the restriction to numeric dtypes only is different from
29+
# np.matmul.
30+
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
31+
raise TypeError('Only numeric dtypes are allowed in matmul')
32+
33+
return Array._new(np.matmul(x1._array, x2._array))
34+
35+
# Note: tensordot is the numpy top-level namespace but not in np.linalg
36+
37+
# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
38+
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
39+
# Note: the restriction to numeric dtypes only is different from
40+
# np.tensordot.
41+
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
42+
raise TypeError('Only numeric dtypes are allowed in tensordot')
43+
44+
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
45+
46+
# Note: this function is new in the array API spec. Unlike transpose, it only
47+
# transposes the last two axes.
48+
def matrix_transpose(x: Array, /) -> Array:
49+
if x.ndim < 2:
50+
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
51+
return Array._new(np.swapaxes(x._array, -1, -2))
52+
53+
# Note: vecdot is not in NumPy
54+
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
55+
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
56+
raise TypeError('Only numeric dtypes are allowed in vecdot')
57+
ndim = max(x1.ndim, x2.ndim)
58+
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
59+
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
60+
if x1_shape[axis] != x2_shape[axis]:
61+
raise ValueError("x1 and x2 must have the same size along the given axis")
62+
63+
x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
64+
x1_ = np.moveaxis(x1_, axis, -1)
65+
x2_ = np.moveaxis(x2_, axis, -1)
66+
67+
res = x1_[..., None, :] @ x2_[..., None]
68+
return Array._new(res[..., 0, 0])

array_api_strict/fft.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
)
1616
from ._array_object import Array, CPU_DEVICE
1717
from ._data_type_functions import astype
18+
from ._flags import requires_extension
1819

1920
import numpy as np
2021

22+
@requires_extension('fft')
2123
def fft(
2224
x: Array,
2325
/,
@@ -40,6 +42,7 @@ def fft(
4042
return astype(res, complex64)
4143
return res
4244

45+
@requires_extension('fft')
4346
def ifft(
4447
x: Array,
4548
/,
@@ -62,6 +65,7 @@ def ifft(
6265
return astype(res, complex64)
6366
return res
6467

68+
@requires_extension('fft')
6569
def fftn(
6670
x: Array,
6771
/,
@@ -84,6 +88,7 @@ def fftn(
8488
return astype(res, complex64)
8589
return res
8690

91+
@requires_extension('fft')
8792
def ifftn(
8893
x: Array,
8994
/,
@@ -106,6 +111,7 @@ def ifftn(
106111
return astype(res, complex64)
107112
return res
108113

114+
@requires_extension('fft')
109115
def rfft(
110116
x: Array,
111117
/,
@@ -128,6 +134,7 @@ def rfft(
128134
return astype(res, complex64)
129135
return res
130136

137+
@requires_extension('fft')
131138
def irfft(
132139
x: Array,
133140
/,
@@ -150,6 +157,7 @@ def irfft(
150157
return astype(res, float32)
151158
return res
152159

160+
@requires_extension('fft')
153161
def rfftn(
154162
x: Array,
155163
/,
@@ -172,6 +180,7 @@ def rfftn(
172180
return astype(res, complex64)
173181
return res
174182

183+
@requires_extension('fft')
175184
def irfftn(
176185
x: Array,
177186
/,
@@ -194,6 +203,7 @@ def irfftn(
194203
return astype(res, float32)
195204
return res
196205

206+
@requires_extension('fft')
197207
def hfft(
198208
x: Array,
199209
/,
@@ -216,6 +226,7 @@ def hfft(
216226
return astype(res, float32)
217227
return res
218228

229+
@requires_extension('fft')
219230
def ihfft(
220231
x: Array,
221232
/,
@@ -238,6 +249,7 @@ def ihfft(
238249
return astype(res, complex64)
239250
return res
240251

252+
@requires_extension('fft')
241253
def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
242254
"""
243255
Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
@@ -248,6 +260,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar
248260
raise ValueError(f"Unsupported device {device!r}")
249261
return Array._new(np.fft.fftfreq(n, d=d))
250262

263+
@requires_extension('fft')
251264
def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
252265
"""
253266
Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
@@ -258,6 +271,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A
258271
raise ValueError(f"Unsupported device {device!r}")
259272
return Array._new(np.fft.rfftfreq(n, d=d))
260273

274+
@requires_extension('fft')
261275
def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
262276
"""
263277
Array API compatible wrapper for :py:func:`np.fft.fftshift <numpy.fft.fftshift>`.
@@ -268,6 +282,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
268282
raise TypeError("Only floating-point dtypes are allowed in fftshift")
269283
return Array._new(np.fft.fftshift(x._array, axes=axes))
270284

285+
@requires_extension('fft')
271286
def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
272287
"""
273288
Array API compatible wrapper for :py:func:`np.fft.ifftshift <numpy.fft.ifftshift>`.

0 commit comments

Comments
 (0)