1
1
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
2
2
3
3
using AbstractFFTs
4
- using AbstractFFTs: Plan
4
+ using AbstractFFTs: Plan, ScaledPlan
5
5
using ChainRulesTestUtils
6
- using ChainRulesCore: NoTangent
6
+ using FiniteDifferences
7
+ import ChainRulesCore
7
8
8
9
using LinearAlgebra
9
10
using Random
293
294
end
294
295
295
296
@testset " fft" begin
296
- for x in (randn (2 ), randn (2 , 3 ), randn (3 , 4 , 5 ))
297
- N = ndims (x)
298
- complex_x = complex .(x)
297
+ # Overloads to allow ChainRulesTestUtils to test rules w.r.t. ScaledPlan's. See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256
298
+ InnerPlan = Union{TestPlan, InverseTestPlan, TestRPlan, InverseTestRPlan}
299
+ function FiniteDifferences. to_vec (x:: InnerPlan )
300
+ function FFTPlan_from_vec (x_vec:: Vector )
301
+ return x
302
+ end
303
+ return Bool[], FFTPlan_from_vec
304
+ end
305
+ ChainRulesTestUtils. test_approx (:: ChainRulesCore.AbstractZero , x:: InnerPlan , msg= " " ; kwargs... ) = true
306
+ ChainRulesTestUtils. rand_tangent (:: AbstractRNG , x:: InnerPlan ) = ChainRulesCore. NoTangent ()
307
+
308
+ for x_shape in ((2 ,), (2 , 3 ), (3 , 4 , 5 ))
309
+ N = length (x_shape)
310
+ x = randn (x_shape)
311
+ complex_x = x + randn (x_shape) * im
299
312
for dims in unique ((1 , 1 : N, N))
300
313
# fft, ifft, bfft
301
314
for f in (fft, ifft, bfft)
@@ -305,17 +318,17 @@ end
305
318
test_rrule (f, complex_x, dims)
306
319
end
307
320
for pf in (plan_fft, plan_ifft, plan_bfft)
308
- test_frule (* , pf (x, dims) ⊢ NoTangent () , x)
309
- test_rrule (* , pf (x, dims) ⊢ NoTangent () , x)
310
- test_frule (* , pf (complex_x, dims) ⊢ NoTangent () , complex_x)
311
- test_rrule (* , pf (complex_x, dims) ⊢ NoTangent () , complex_x)
321
+ test_frule (* , pf (x, dims), x)
322
+ test_rrule (* , pf (x, dims), x)
323
+ test_frule (* , pf (complex_x, dims), complex_x)
324
+ test_rrule (* , pf (complex_x, dims), complex_x)
312
325
end
313
326
314
327
# rfft
315
328
test_frule (rfft, x, dims)
316
329
test_rrule (rfft, x, dims)
317
- test_frule (* , plan_rfft (x, dims) ⊢ NoTangent () , x)
318
- test_rrule (* , plan_rfft (x, dims) ⊢ NoTangent () , x)
330
+ test_frule (* , plan_rfft (x, dims), x)
331
+ test_rrule (* , plan_rfft (x, dims), x)
319
332
320
333
# irfft, brfft
321
334
for f in (irfft, brfft)
328
341
end
329
342
for pf in (plan_irfft, plan_brfft)
330
343
for d in (2 * size (x, first (dims)) - 1 , 2 * size (x, first (dims)) - 2 )
331
- test_frule (* , pf (complex_x, d, dims) ⊢ NoTangent () , complex_x)
332
- test_rrule (* , pf (complex_x, d, dims) ⊢ NoTangent () , complex_x)
344
+ test_frule (* , pf (complex_x, d, dims), complex_x)
345
+ test_rrule (* , pf (complex_x, d, dims), complex_x)
333
346
end
334
347
end
335
348
end
0 commit comments