Skip to content

Commit 9b75b86

Browse files
authored
fix to fft, implement fft2, rfft, and irfft for mlx (#20781)
1 parent 603affa commit 9b75b86

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

keras/src/backend/mlx/math.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,56 @@ def extract_sequences(x, sequence_length, sequence_stride):
5454
return x.reshape(*batch_shape, frames, sequence_length)
5555

5656

57+
def _get_complex_tensor_from_tuple(x):
58+
if not isinstance(x, (tuple, list)) or len(x) != 2:
59+
raise ValueError(
60+
"Input `x` should be a tuple of two tensors - real and imaginary."
61+
f"Received: x={x}"
62+
)
63+
real, imag = x
64+
real = convert_to_tensor(real)
65+
imag = convert_to_tensor(imag)
66+
# Check shapes.
67+
if real.shape != imag.shape:
68+
raise ValueError(
69+
"Input `x` should be a tuple of two tensors - real and imaginary."
70+
"Both the real and imaginary parts should have the same shape. "
71+
f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}"
72+
)
73+
# Ensure dtype is float.
74+
if not mx.issubdtype(real.dtype, mx.floating) or not mx.issubdtype(
75+
imag.dtype, mx.floating
76+
):
77+
raise ValueError(
78+
"At least one tensor in input `x` is not of type float."
79+
f"Received: x={x}."
80+
)
81+
complex_input = mx.add(real, 1j * imag)
82+
return complex_input
83+
84+
5785
def fft(x):
58-
x = convert_to_tensor(x)
59-
return mx.fft(x)
86+
x = _get_complex_tensor_from_tuple(x)
87+
complex_output = mx.fft.fft(x)
88+
return mx.real(complex_output), mx.imag(complex_output)
6089

6190

6291
def fft2(x):
63-
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
64-
raise NotImplementedError("fft not yet implemented in mlx")
92+
x = _get_complex_tensor_from_tuple(x)
93+
complex_output = mx.fft.fft2(x)
94+
return mx.real(complex_output), mx.imag(complex_output)
6595

6696

6797
def rfft(x, fft_length=None):
68-
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
69-
raise NotImplementedError("fft not yet implemented in mlx")
98+
x = convert_to_tensor(x)
99+
complex_output = mx.fft.rfft(x, n=fft_length)
100+
return mx.real(complex_output), mx.imag(complex_output)
70101

71102

72103
def irfft(x, fft_length=None):
73-
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
74-
raise NotImplementedError("fft not yet implemented in mlx")
104+
x = _get_complex_tensor_from_tuple(x)
105+
real_output = mx.fft.irfft(x, n=fft_length)
106+
return real_output
75107

76108

77109
def stft(

0 commit comments

Comments
 (0)