Skip to content

Commit 1ea7ecd

Browse files
committed
Add wrappers for torch.fft
The only thing that needs to be wrapped is a few functions which do not properly map axes to dim.
1 parent f97b59e commit 1ea7ecd

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

array_api_compat/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
# See the comment in the numpy __init__.py
1818
__import__(__package__ + '.linalg')
1919

20+
__import__(__package__ + '.fft')
21+
2022
from ..common._helpers import * # noqa: F403
2123

2224
__array_api_version__ = '2022.12'

array_api_compat/torch/fft.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
if TYPE_CHECKING:
5+
import torch
6+
array = torch.Tensor
7+
from typing import Union, Sequence, Literal
8+
9+
from torch.fft import * # noqa: F403
10+
import torch.fft
11+
12+
# Several torch fft functions do not map axes to dim
13+
14+
def fftn(
15+
x: array,
16+
/,
17+
*,
18+
s: Sequence[int] = None,
19+
axes: Sequence[int] = None,
20+
norm: Literal["backward", "ortho", "forward"] = "backward",
21+
**kwargs,
22+
) -> array:
23+
return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
24+
25+
def ifftn(
26+
x: array,
27+
/,
28+
*,
29+
s: Sequence[int] = None,
30+
axes: Sequence[int] = None,
31+
norm: Literal["backward", "ortho", "forward"] = "backward",
32+
**kwargs,
33+
) -> array:
34+
return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
35+
36+
def rfftn(
37+
x: array,
38+
/,
39+
*,
40+
s: Sequence[int] = None,
41+
axes: Sequence[int] = None,
42+
norm: Literal["backward", "ortho", "forward"] = "backward",
43+
**kwargs,
44+
) -> array:
45+
return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
46+
47+
def irfftn(
48+
x: array,
49+
/,
50+
*,
51+
s: Sequence[int] = None,
52+
axes: Sequence[int] = None,
53+
norm: Literal["backward", "ortho", "forward"] = "backward",
54+
**kwargs,
55+
) -> array:
56+
return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
57+
58+
def fftshift(
59+
x: array,
60+
/,
61+
*,
62+
axes: Union[int, Sequence[int]] = None,
63+
**kwargs,
64+
) -> array:
65+
return torch.fft.fftshift(x, dim=axes, **kwargs)
66+
67+
def ifftshift(
68+
x: array,
69+
/,
70+
*,
71+
axes: Union[int, Sequence[int]] = None,
72+
**kwargs,
73+
) -> array:
74+
return torch.fft.ifftshift(x, dim=axes, **kwargs)
75+
76+
77+
__all__ = torch.fft.__all__ + [
78+
"fftn",
79+
"ifftn",
80+
"rfftn",
81+
"irfftn",
82+
"fftshift",
83+
"ifftshift",
84+
]

0 commit comments

Comments
 (0)