@@ -54,24 +54,56 @@ def extract_sequences(x, sequence_length, sequence_stride):
54
54
return x .reshape (* batch_shape , frames , sequence_length )
55
55
56
56
57
+ def _get_complex_tensor_from_tuple (x ):
58
+ if not isinstance (x , (tuple , list )) or len (x ) != 2 :
59
+ raise ValueError (
60
+ "Input `x` should be a tuple of two tensors - real and imaginary."
61
+ f"Received: x={ x } "
62
+ )
63
+ real , imag = x
64
+ real = convert_to_tensor (real )
65
+ imag = convert_to_tensor (imag )
66
+ # Check shapes.
67
+ if real .shape != imag .shape :
68
+ raise ValueError (
69
+ "Input `x` should be a tuple of two tensors - real and imaginary."
70
+ "Both the real and imaginary parts should have the same shape. "
71
+ f"Received: x[0].shape = { real .shape } , x[1].shape = { imag .shape } "
72
+ )
73
+ # Ensure dtype is float.
74
+ if not mx .issubdtype (real .dtype , mx .floating ) or not mx .issubdtype (
75
+ imag .dtype , mx .floating
76
+ ):
77
+ raise ValueError (
78
+ "At least one tensor in input `x` is not of type float."
79
+ f"Received: x={ x } ."
80
+ )
81
+ complex_input = mx .add (real , 1j * imag )
82
+ return complex_input
83
+
84
+
57
85
def fft (x ):
58
- x = convert_to_tensor (x )
59
- return mx .fft (x )
86
+ x = _get_complex_tensor_from_tuple (x )
87
+ complex_output = mx .fft .fft (x )
88
+ return mx .real (complex_output ), mx .imag (complex_output )
60
89
61
90
62
91
def fft2 (x ):
63
- # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
64
- raise NotImplementedError ("fft not yet implemented in mlx" )
92
+ x = _get_complex_tensor_from_tuple (x )
93
+ complex_output = mx .fft .fft2 (x )
94
+ return mx .real (complex_output ), mx .imag (complex_output )
65
95
66
96
67
97
def rfft (x , fft_length = None ):
68
- # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
69
- raise NotImplementedError ("fft not yet implemented in mlx" )
98
+ x = convert_to_tensor (x )
99
+ complex_output = mx .fft .rfft (x , n = fft_length )
100
+ return mx .real (complex_output ), mx .imag (complex_output )
70
101
71
102
72
103
def irfft (x , fft_length = None ):
73
- # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
74
- raise NotImplementedError ("fft not yet implemented in mlx" )
104
+ x = _get_complex_tensor_from_tuple (x )
105
+ real_output = mx .fft .irfft (x , n = fft_length )
106
+ return real_output
75
107
76
108
77
109
def stft (
0 commit comments