Skip to content

Commit d197e5e

Browse files
authored
ENH: Add fft optional extension submodule to numpy.array_api (#25317)
Original NumPy Commit: 8e8da65678ac85b3c2756935d7c66c8b3203ca92
1 parent 68cdfaf commit d197e5e

File tree

2 files changed

+299
-0
lines changed

2 files changed

+299
-0
lines changed

array_api_strict/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@
353353

354354
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
355355

356+
from . import fft
357+
__all__ += ["fft"]
358+
356359
from ._manipulation_functions import (
357360
concat,
358361
expand_dims,

array_api_strict/fft.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Union, Optional, Literal
4+
5+
if TYPE_CHECKING:
6+
from ._typing import Device
7+
from collections.abc import Sequence
8+
9+
from ._dtypes import (
10+
_floating_dtypes,
11+
_real_floating_dtypes,
12+
_complex_floating_dtypes,
13+
float32,
14+
complex64,
15+
)
16+
from ._array_object import Array
17+
from ._data_type_functions import astype
18+
19+
import numpy as np
20+
21+
def fft(
22+
x: Array,
23+
/,
24+
*,
25+
n: Optional[int] = None,
26+
axis: int = -1,
27+
norm: Literal["backward", "ortho", "forward"] = "backward",
28+
) -> Array:
29+
"""
30+
Array API compatible wrapper for :py:func:`np.fft.fft <numpy.fft.fft>`.
31+
32+
See its docstring for more information.
33+
"""
34+
if x.dtype not in _complex_floating_dtypes:
35+
raise TypeError("Only complex floating-point dtypes are allowed in fft")
36+
res = Array._new(np.fft.fft(x._array, n=n, axis=axis, norm=norm))
37+
# Note: np.fft functions improperly upcast float32 and complex64 to
38+
# complex128
39+
if x.dtype == complex64:
40+
return astype(res, complex64)
41+
return res
42+
43+
def ifft(
44+
x: Array,
45+
/,
46+
*,
47+
n: Optional[int] = None,
48+
axis: int = -1,
49+
norm: Literal["backward", "ortho", "forward"] = "backward",
50+
) -> Array:
51+
"""
52+
Array API compatible wrapper for :py:func:`np.fft.ifft <numpy.fft.ifft>`.
53+
54+
See its docstring for more information.
55+
"""
56+
if x.dtype not in _complex_floating_dtypes:
57+
raise TypeError("Only complex floating-point dtypes are allowed in ifft")
58+
res = Array._new(np.fft.ifft(x._array, n=n, axis=axis, norm=norm))
59+
# Note: np.fft functions improperly upcast float32 and complex64 to
60+
# complex128
61+
if x.dtype == complex64:
62+
return astype(res, complex64)
63+
return res
64+
65+
def fftn(
66+
x: Array,
67+
/,
68+
*,
69+
s: Sequence[int] = None,
70+
axes: Sequence[int] = None,
71+
norm: Literal["backward", "ortho", "forward"] = "backward",
72+
) -> Array:
73+
"""
74+
Array API compatible wrapper for :py:func:`np.fft.fftn <numpy.fft.fftn>`.
75+
76+
See its docstring for more information.
77+
"""
78+
if x.dtype not in _complex_floating_dtypes:
79+
raise TypeError("Only complex floating-point dtypes are allowed in fftn")
80+
res = Array._new(np.fft.fftn(x._array, s=s, axes=axes, norm=norm))
81+
# Note: np.fft functions improperly upcast float32 and complex64 to
82+
# complex128
83+
if x.dtype == complex64:
84+
return astype(res, complex64)
85+
return res
86+
87+
def ifftn(
88+
x: Array,
89+
/,
90+
*,
91+
s: Sequence[int] = None,
92+
axes: Sequence[int] = None,
93+
norm: Literal["backward", "ortho", "forward"] = "backward",
94+
) -> Array:
95+
"""
96+
Array API compatible wrapper for :py:func:`np.fft.ifftn <numpy.fft.ifftn>`.
97+
98+
See its docstring for more information.
99+
"""
100+
if x.dtype not in _complex_floating_dtypes:
101+
raise TypeError("Only complex floating-point dtypes are allowed in ifftn")
102+
res = Array._new(np.fft.ifftn(x._array, s=s, axes=axes, norm=norm))
103+
# Note: np.fft functions improperly upcast float32 and complex64 to
104+
# complex128
105+
if x.dtype == complex64:
106+
return astype(res, complex64)
107+
return res
108+
109+
def rfft(
110+
x: Array,
111+
/,
112+
*,
113+
n: Optional[int] = None,
114+
axis: int = -1,
115+
norm: Literal["backward", "ortho", "forward"] = "backward",
116+
) -> Array:
117+
"""
118+
Array API compatible wrapper for :py:func:`np.fft.rfft <numpy.fft.rfft>`.
119+
120+
See its docstring for more information.
121+
"""
122+
if x.dtype not in _real_floating_dtypes:
123+
raise TypeError("Only real floating-point dtypes are allowed in rfft")
124+
res = Array._new(np.fft.rfft(x._array, n=n, axis=axis, norm=norm))
125+
# Note: np.fft functions improperly upcast float32 and complex64 to
126+
# complex128
127+
if x.dtype == float32:
128+
return astype(res, complex64)
129+
return res
130+
131+
def irfft(
132+
x: Array,
133+
/,
134+
*,
135+
n: Optional[int] = None,
136+
axis: int = -1,
137+
norm: Literal["backward", "ortho", "forward"] = "backward",
138+
) -> Array:
139+
"""
140+
Array API compatible wrapper for :py:func:`np.fft.irfft <numpy.fft.irfft>`.
141+
142+
See its docstring for more information.
143+
"""
144+
if x.dtype not in _complex_floating_dtypes:
145+
raise TypeError("Only complex floating-point dtypes are allowed in irfft")
146+
res = Array._new(np.fft.irfft(x._array, n=n, axis=axis, norm=norm))
147+
# Note: np.fft functions improperly upcast float32 and complex64 to
148+
# complex128
149+
if x.dtype == complex64:
150+
return astype(res, float32)
151+
return res
152+
153+
def rfftn(
154+
x: Array,
155+
/,
156+
*,
157+
s: Sequence[int] = None,
158+
axes: Sequence[int] = None,
159+
norm: Literal["backward", "ortho", "forward"] = "backward",
160+
) -> Array:
161+
"""
162+
Array API compatible wrapper for :py:func:`np.fft.rfftn <numpy.fft.rfftn>`.
163+
164+
See its docstring for more information.
165+
"""
166+
if x.dtype not in _real_floating_dtypes:
167+
raise TypeError("Only real floating-point dtypes are allowed in rfftn")
168+
res = Array._new(np.fft.rfftn(x._array, s=s, axes=axes, norm=norm))
169+
# Note: np.fft functions improperly upcast float32 and complex64 to
170+
# complex128
171+
if x.dtype == float32:
172+
return astype(res, complex64)
173+
return res
174+
175+
def irfftn(
176+
x: Array,
177+
/,
178+
*,
179+
s: Sequence[int] = None,
180+
axes: Sequence[int] = None,
181+
norm: Literal["backward", "ortho", "forward"] = "backward",
182+
) -> Array:
183+
"""
184+
Array API compatible wrapper for :py:func:`np.fft.irfftn <numpy.fft.irfftn>`.
185+
186+
See its docstring for more information.
187+
"""
188+
if x.dtype not in _complex_floating_dtypes:
189+
raise TypeError("Only complex floating-point dtypes are allowed in irfftn")
190+
res = Array._new(np.fft.irfftn(x._array, s=s, axes=axes, norm=norm))
191+
# Note: np.fft functions improperly upcast float32 and complex64 to
192+
# complex128
193+
if x.dtype == complex64:
194+
return astype(res, float32)
195+
return res
196+
197+
def hfft(
198+
x: Array,
199+
/,
200+
*,
201+
n: Optional[int] = None,
202+
axis: int = -1,
203+
norm: Literal["backward", "ortho", "forward"] = "backward",
204+
) -> Array:
205+
"""
206+
Array API compatible wrapper for :py:func:`np.fft.hfft <numpy.fft.hfft>`.
207+
208+
See its docstring for more information.
209+
"""
210+
if x.dtype not in _complex_floating_dtypes:
211+
raise TypeError("Only complex floating-point dtypes are allowed in hfft")
212+
res = Array._new(np.fft.hfft(x._array, n=n, axis=axis, norm=norm))
213+
# Note: np.fft functions improperly upcast float32 and complex64 to
214+
# complex128
215+
if x.dtype == complex64:
216+
return astype(res, float32)
217+
return res
218+
219+
def ihfft(
220+
x: Array,
221+
/,
222+
*,
223+
n: Optional[int] = None,
224+
axis: int = -1,
225+
norm: Literal["backward", "ortho", "forward"] = "backward",
226+
) -> Array:
227+
"""
228+
Array API compatible wrapper for :py:func:`np.fft.ihfft <numpy.fft.ihfft>`.
229+
230+
See its docstring for more information.
231+
"""
232+
if x.dtype not in _real_floating_dtypes:
233+
raise TypeError("Only real floating-point dtypes are allowed in ihfft")
234+
res = Array._new(np.fft.ihfft(x._array, n=n, axis=axis, norm=norm))
235+
# Note: np.fft functions improperly upcast float32 and complex64 to
236+
# complex128
237+
if x.dtype == float32:
238+
return astype(res, complex64)
239+
return res
240+
241+
def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
242+
"""
243+
Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
244+
245+
See its docstring for more information.
246+
"""
247+
if device not in ["cpu", None]:
248+
raise ValueError(f"Unsupported device {device!r}")
249+
return Array._new(np.fft.fftfreq(n, d=d))
250+
251+
def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
252+
"""
253+
Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
254+
255+
See its docstring for more information.
256+
"""
257+
if device not in ["cpu", None]:
258+
raise ValueError(f"Unsupported device {device!r}")
259+
return Array._new(np.fft.rfftfreq(n, d=d))
260+
261+
def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
262+
"""
263+
Array API compatible wrapper for :py:func:`np.fft.fftshift <numpy.fft.fftshift>`.
264+
265+
See its docstring for more information.
266+
"""
267+
if x.dtype not in _floating_dtypes:
268+
raise TypeError("Only floating-point dtypes are allowed in fftshift")
269+
return Array._new(np.fft.fftshift(x._array, axes=axes))
270+
271+
def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
272+
"""
273+
Array API compatible wrapper for :py:func:`np.fft.ifftshift <numpy.fft.ifftshift>`.
274+
275+
See its docstring for more information.
276+
"""
277+
if x.dtype not in _floating_dtypes:
278+
raise TypeError("Only floating-point dtypes are allowed in ifftshift")
279+
return Array._new(np.fft.ifftshift(x._array, axes=axes))
280+
281+
__all__ = [
282+
"fft",
283+
"ifft",
284+
"fftn",
285+
"ifftn",
286+
"rfft",
287+
"irfft",
288+
"rfftn",
289+
"irfftn",
290+
"hfft",
291+
"ihfft",
292+
"fftfreq",
293+
"rfftfreq",
294+
"fftshift",
295+
"ifftshift",
296+
]

0 commit comments

Comments
 (0)