3
3
using AbstractFFTs
4
4
using AbstractFFTs: Plan
5
5
using ChainRulesTestUtils
6
+ using ChainRulesCore: NoTangent
6
7
7
8
using LinearAlgebra
8
9
using Random
197
198
@test @inferred (f9 (plan_fft (zeros (10 ), 1 ), 10 )) == 1 / 10
198
199
end
199
200
201
+ @testset " output size" begin
202
+ @testset " complex fft output size" begin
203
+ for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
204
+ N = ndims (x)
205
+ y = randn (size (x))
206
+ for dims in unique ((1 , 1 : N, N))
207
+ P = plan_fft (x, dims)
208
+ @test AbstractFFTs. output_size (P) == size (x)
209
+ @test AbstractFFTs. output_size (P' ) == size (x)
210
+ Pinv = plan_ifft (x)
211
+ @test AbstractFFTs. output_size (Pinv) == size (x)
212
+ @test AbstractFFTs. output_size (Pinv' ) == size (x)
213
+ end
214
+ end
215
+ end
216
+ @testset " real fft output size" begin
217
+ for x in (randn (3 ), randn (4 ), randn (3 , 4 ), randn (3 , 4 , 5 )) # test odd and even lengths
218
+ N = ndims (x)
219
+ for dims in unique ((1 , 1 : N, N))
220
+ P = plan_rfft (x, dims)
221
+ Px_sz = size (P * x)
222
+ @test AbstractFFTs. output_size (P) == Px_sz
223
+ @test AbstractFFTs. output_size (P' ) == size (x)
224
+ y = randn (Px_sz) .+ randn (Px_sz) * im
225
+ Pinv = plan_irfft (y, size (x)[first (dims)], dims)
226
+ @test AbstractFFTs. output_size (Pinv) == size (Pinv * y)
227
+ @test AbstractFFTs. output_size (Pinv' ) == size (y)
228
+ end
229
+ end
230
+ end
231
+ end
232
+
233
+ @testset " adjoint" begin
234
+ @testset " complex fft adjoint" begin
235
+ for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
236
+ N = ndims (x)
237
+ y = randn (size (x))
238
+ for dims in unique ((1 , 1 : N, N))
239
+ P = plan_fft (x, dims)
240
+ @test (P' )' * x == P * x # test adjoint of adjoint
241
+ @test size (P' ) == AbstractFFTs. output_size (P) # test size of adjoint
242
+ @test dot (y, P * x) ≈ dot (P' * y, x) # test validity of adjoint
243
+ @test_broken dot (y, P \ x) ≈ dot (P' \ y, x)
244
+ Pinv = plan_ifft (y)
245
+ @test (Pinv' )' * y == Pinv * y
246
+ @test size (Pinv' ) == AbstractFFTs. output_size (Pinv)
247
+ @test dot (x, Pinv * y) ≈ dot (Pinv' * x, y)
248
+ @test_broken dot (x, Pinv \ y) ≈ dot (Pinv' \ x, y)
249
+ end
250
+ end
251
+ end
252
+ @testset " real fft adjoint" begin
253
+ for x in (randn (3 ), randn (4 ), randn (3 , 4 ), randn (3 , 4 , 5 )) # test odd and even lengths
254
+ N = ndims (x)
255
+ for dims in unique ((1 , 1 : N, N))
256
+ P = plan_rfft (x, dims)
257
+ y_real = randn (size (P * x))
258
+ y_imag = randn (size (P * x))
259
+ y = y_real .+ y_imag .* im
260
+ @test (P' )' * x == P * x
261
+ @test size (P' ) == AbstractFFTs. output_size (P)
262
+ @test dot (y_real, real .(P * x)) + dot (y_imag, imag .(P * x)) ≈ dot (P' * y, x)
263
+ @test_broken dot (y_real, real .(P \ x)) + dot (y_imag, imag .(P \ x)) ≈ dot (P' * y, x)
264
+ Pinv = plan_irfft (y, size (x)[first (dims)], dims)
265
+ @test (Pinv' )' * y == Pinv * y
266
+ @test size (Pinv' ) == AbstractFFTs. output_size (Pinv)
267
+ @test dot (x, Pinv * y) ≈ dot (y_real, real .(Pinv' * x)) + dot (y_imag, imag .(Pinv' * x))
268
+ @test_broken dot (x, Pinv \ y) ≈ dot (y_real, real .(Pinv' \ x)) + dot (y_imag, imag .(Pinv' \ x))
269
+ end
270
+ end
271
+ end
272
+ end
273
+
200
274
@testset " ChainRules" begin
201
275
@testset " shift functions" begin
202
276
for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
@@ -218,20 +292,31 @@ end
218
292
end
219
293
220
294
@testset " fft" begin
221
- for x in (randn (3 ), randn (3 , 4 ), randn (3 , 4 , 5 ))
295
+ for x in (randn (2 ), randn (2 , 3 ), randn (3 , 4 , 5 ))
222
296
N = ndims (x)
223
297
complex_x = complex .(x)
224
298
for dims in unique ((1 , 1 : N, N))
299
+ # fft, ifft, bfft
225
300
for f in (fft, ifft, bfft)
226
301
test_frule (f, x, dims)
227
302
test_rrule (f, x, dims)
228
303
test_frule (f, complex_x, dims)
229
304
test_rrule (f, complex_x, dims)
230
305
end
306
+ for pf in (plan_fft, plan_ifft, plan_bfft)
307
+ test_frule (* , pf (x, dims) ⊢ NoTangent (), x)
308
+ test_rrule (* , pf (x, dims) ⊢ NoTangent (), x)
309
+ test_frule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
310
+ test_rrule (* , pf (complex_x, dims) ⊢ NoTangent (), complex_x)
311
+ end
231
312
313
+ # rfft
232
314
test_frule (rfft, x, dims)
233
315
test_rrule (rfft, x, dims)
316
+ test_frule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
317
+ test_rrule (* , plan_rfft (x, dims) ⊢ NoTangent (), x)
234
318
319
+ # irfft, brfft
235
320
for f in (irfft, brfft)
236
321
for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
237
322
test_frule (f, x, d, dims)
240
325
test_rrule (f, complex_x, d, dims)
241
326
end
242
327
end
328
+ for pf in (plan_irfft, plan_brfft)
329
+ for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
330
+ test_frule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
331
+ test_rrule (* , pf (complex_x, d, dims) ⊢ NoTangent (), complex_x)
332
+ end
333
+ end
243
334
end
244
335
end
245
336
end
0 commit comments