Skip to content

Commit f228b58

Browse files
committed
Add fft support for numpy and cupy
This is based off of numpy/numpy#25317
1 parent d235910 commit f228b58

File tree

5 files changed

+245
-0
lines changed

5 files changed

+245
-0
lines changed

array_api_compat/common/_fft.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Union, Optional, Literal
4+
5+
if TYPE_CHECKING:
6+
from ._typing import Device, ndarray
7+
from collections.abc import Sequence
8+
9+
# Note: NumPy fft functions improperly upcast float32 and complex64 to
10+
# complex128, which is why we require wrapping them all here.
11+
12+
def fft(
13+
x: ndarray,
14+
/,
15+
xp,
16+
*,
17+
n: Optional[int] = None,
18+
axis: int = -1,
19+
norm: Literal["backward", "ortho", "forward"] = "backward",
20+
) -> ndarray:
21+
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
22+
if x.dtype in [xp.float32, xp.complex64]:
23+
return res.astype(xp.complex64)
24+
return res
25+
26+
def ifft(
27+
x: ndarray,
28+
/,
29+
xp,
30+
*,
31+
n: Optional[int] = None,
32+
axis: int = -1,
33+
norm: Literal["backward", "ortho", "forward"] = "backward",
34+
) -> ndarray:
35+
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
36+
if x.dtype in [xp.float32, xp.complex64]:
37+
return res.astype(xp.complex64)
38+
return res
39+
40+
def fftn(
41+
x: ndarray,
42+
/,
43+
xp,
44+
*,
45+
s: Sequence[int] = None,
46+
axes: Sequence[int] = None,
47+
norm: Literal["backward", "ortho", "forward"] = "backward",
48+
) -> ndarray:
49+
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
50+
if x.dtype in [xp.float32, xp.complex64]:
51+
return res.astype(xp.complex64)
52+
return res
53+
54+
def ifftn(
55+
x: ndarray,
56+
/,
57+
xp,
58+
*,
59+
s: Sequence[int] = None,
60+
axes: Sequence[int] = None,
61+
norm: Literal["backward", "ortho", "forward"] = "backward",
62+
) -> ndarray:
63+
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
64+
if x.dtype in [xp.float32, xp.complex64]:
65+
return res.astype(xp.complex64)
66+
return res
67+
68+
def rfft(
69+
x: ndarray,
70+
/,
71+
xp,
72+
*,
73+
n: Optional[int] = None,
74+
axis: int = -1,
75+
norm: Literal["backward", "ortho", "forward"] = "backward",
76+
) -> ndarray:
77+
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
78+
if x.dtype == xp.float32:
79+
return res.astype(xp.complex64)
80+
return res
81+
82+
def irfft(
83+
x: ndarray,
84+
/,
85+
xp,
86+
*,
87+
n: Optional[int] = None,
88+
axis: int = -1,
89+
norm: Literal["backward", "ortho", "forward"] = "backward",
90+
) -> ndarray:
91+
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
92+
if x.dtype == xp.complex64:
93+
return res.astype(xp.float32)
94+
return res
95+
96+
def rfftn(
97+
x: ndarray,
98+
/,
99+
xp,
100+
*,
101+
s: Sequence[int] = None,
102+
axes: Sequence[int] = None,
103+
norm: Literal["backward", "ortho", "forward"] = "backward",
104+
) -> ndarray:
105+
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
106+
if x.dtype == xp.float32:
107+
return res.astype(xp.complex64)
108+
return res
109+
110+
def irfftn(
111+
x: ndarray,
112+
/,
113+
xp,
114+
*,
115+
s: Sequence[int] = None,
116+
axes: Sequence[int] = None,
117+
norm: Literal["backward", "ortho", "forward"] = "backward",
118+
) -> ndarray:
119+
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
120+
if x.dtype == xp.complex64:
121+
return res.astype(xp.float32)
122+
return res
123+
124+
def hfft(
125+
x: ndarray,
126+
/,
127+
xp,
128+
*,
129+
n: Optional[int] = None,
130+
axis: int = -1,
131+
norm: Literal["backward", "ortho", "forward"] = "backward",
132+
) -> ndarray:
133+
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
134+
if x.dtype in [xp.float32, xp.complex64]:
135+
return res.astype(xp.complex64)
136+
return res
137+
138+
def ihfft(
139+
x: ndarray,
140+
/,
141+
xp,
142+
*,
143+
n: Optional[int] = None,
144+
axis: int = -1,
145+
norm: Literal["backward", "ortho", "forward"] = "backward",
146+
) -> ndarray:
147+
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
148+
if x.dtype in [xp.float32, xp.complex64]:
149+
return res.astype(xp.complex64)
150+
return res
151+
152+
def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
153+
if device not in ["cpu", None]:
154+
raise ValueError(f"Unsupported device {device!r}")
155+
return xp.fft.fftfreq(n, d=d)
156+
157+
def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
158+
if device not in ["cpu", None]:
159+
raise ValueError(f"Unsupported device {device!r}")
160+
return xp.fft.rfftfreq(n, d=d)
161+
162+
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
163+
return xp.fft.fftshift(x, axes=axes)
164+
165+
def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
166+
return xp.fft.ifftshift(x, axes=axes)
167+
168+
__all__ = [
169+
"fft",
170+
"ifft",
171+
"fftn",
172+
"ifftn",
173+
"rfft",
174+
"irfft",
175+
"rfftn",
176+
"irfftn",
177+
"hfft",
178+
"ihfft",
179+
"fftfreq",
180+
"rfftfreq",
181+
"fftshift",
182+
"ifftshift",
183+
]

array_api_compat/cupy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the comment in the numpy __init__.py
1010
__import__(__package__ + '.linalg')
1111

12+
__import__(__package__ + '.fft')
13+
1214
from .linalg import matrix_transpose, vecdot
1315

1416
from ..common._helpers import *

array_api_compat/cupy/fft.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from cupy.fft import *
2+
from cupy.fft import __all__ as fft_all
3+
4+
from ..common import _fft
5+
from .._internal import get_xp
6+
7+
import cupy as cp
8+
9+
fft = get_xp(cp)(_fft.fft),
10+
ifft = get_xp(cp)(_fft.ifft),
11+
fftn = get_xp(cp)(_fft.fftn),
12+
ifftn = get_xp(cp)(_fft.ifftn),
13+
rfft = get_xp(cp)(_fft.rfft),
14+
irfft = get_xp(cp)(_fft.irfft),
15+
rfftn = get_xp(cp)(_fft.rfftn),
16+
irfftn = get_xp(cp)(_fft.irfftn),
17+
hfft = get_xp(cp)(_fft.hfft),
18+
ihfft = get_xp(cp)(_fft.ihfft),
19+
fftfreq = get_xp(cp)(_fft.fftfreq),
20+
rfftfreq = get_xp(cp)(_fft.rfftfreq),
21+
fftshift = get_xp(cp)(_fft.fftshift),
22+
ifftshift = get_xp(cp)(_fft.ifftshift),
23+
24+
__all__ = fft_all + _fft.__all__
25+
26+
del get_xp
27+
del cp
28+
del fft_all
29+
del _fft

array_api_compat/numpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# dynamically so that the library can be vendored.
1616
__import__(__package__ + '.linalg')
1717

18+
__import__(__package__ + '.fft')
19+
1820
from .linalg import matrix_transpose, vecdot
1921

2022
from ..common._helpers import *

array_api_compat/numpy/fft.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from numpy.fft import *
2+
from numpy.fft import __all__ as fft_all
3+
4+
from ..common import _fft
5+
from .._internal import get_xp
6+
7+
import numpy as np
8+
9+
fft = get_xp(np)(_fft.fft)
10+
ifft = get_xp(np)(_fft.ifft)
11+
fftn = get_xp(np)(_fft.fftn)
12+
ifftn = get_xp(np)(_fft.ifftn)
13+
rfft = get_xp(np)(_fft.rfft)
14+
irfft = get_xp(np)(_fft.irfft)
15+
rfftn = get_xp(np)(_fft.rfftn)
16+
irfftn = get_xp(np)(_fft.irfftn)
17+
hfft = get_xp(np)(_fft.hfft)
18+
ihfft = get_xp(np)(_fft.ihfft)
19+
fftfreq = get_xp(np)(_fft.fftfreq)
20+
rfftfreq = get_xp(np)(_fft.rfftfreq)
21+
fftshift = get_xp(np)(_fft.fftshift)
22+
ifftshift = get_xp(np)(_fft.ifftshift)
23+
24+
__all__ = fft_all + _fft.__all__
25+
26+
del get_xp
27+
del np
28+
del fft_all
29+
del _fft

0 commit comments

Comments
 (0)