Skip to content

Commit 5482609

Browse files
committed
Cleanup fft backend tests
1 parent bcdb4f5 commit 5482609

File tree

2 files changed

+26
-24
lines changed

2 files changed

+26
-24
lines changed

src/TestUtils.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,12 @@ using AbstractFFTs: Plan
1010
using LinearAlgebra
1111
using Test
1212

13-
import Unitful
14-
15-
function test_fft_backend()
16-
@testset "rfft sizes" begin
17-
A = rand(11, 10)
18-
@test @inferred(AbstractFFTs.rfft_output_size(A, 1)) == (6, 10)
19-
@test @inferred(AbstractFFTs.rfft_output_size(A, 2)) == (11, 6)
20-
A1 = rand(6, 10); A2 = rand(11, 6)
21-
@test @inferred(AbstractFFTs.brfft_output_size(A1, 11, 1)) == (11, 10)
22-
@test @inferred(AbstractFFTs.brfft_output_size(A2, 10, 2)) == (11, 10)
23-
@test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2)
24-
end
25-
13+
"""
14+
"""
15+
function test_fft_backend(array_constructor)
2616
@testset "fft correctness" begin
2717
# DFT along last dimension, results computed using FFTW
28-
for (x, fftw_fft) in (
18+
for (_x, _fftw_fft) in (
2919
(collect(1:7),
3020
[28.0 + 0.0im,
3121
-3.5 + 7.267824888003178im,
@@ -51,6 +41,9 @@ function test_fft_backend()
5141
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
5242
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
5343
)
44+
x = array_constructor(_x)
45+
fftw_fft = array_constructor(_fftw_fft)
46+
5447
# FFT
5548
dims = ndims(x)
5649
y = AbstractFFTs.fft(x, dims)
@@ -59,39 +52,37 @@ function test_fft_backend()
5952
# functionally identical plans
6053
for P in [plan_fft(x, dims), inv(plan_ifft(x, dims)),
6154
AbstractFFTs.plan_inv(plan_ifft(x, dims))]
62-
@test eltype(P) === ComplexF64
55+
@test eltype(P) <: Complex
6356
@test P * x fftw_fft
6457
@test P \ (P * x) x
6558
@test fftdims(P) == dims
6659
end
6760

61+
# BFFT
6862
fftw_bfft = complex.(size(x, dims) .* x)
6963
@test AbstractFFTs.bfft(y, dims) fftw_bfft
7064
P = plan_bfft(x, dims)
7165
@test P * y fftw_bfft
7266
@test P \ (P * y) y
7367
@test fftdims(P) == dims
7468

69+
# IFFT
7570
fftw_ifft = complex.(x)
7671
@test AbstractFFTs.ifft(y, dims) fftw_ifft
77-
# test plan_ifft and also inv and plan_inv of plan_fft, which should all give
78-
# functionally identical plans
7972
for P in [plan_ifft(x, dims), inv(plan_fft(x, dims)),
8073
AbstractFFTs.plan_inv(plan_fft(x, dims))]
8174
@test P * y fftw_ifft
8275
@test P \ (P * y) y
8376
@test fftdims(P) == dims
8477
end
8578

86-
# real FFT
79+
# RFFT
8780
fftw_rfft = fftw_fft[
8881
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
8982
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
9083
]
9184
ry = AbstractFFTs.rfft(x, dims)
9285
@test ry fftw_rfft
93-
# test plan_rfft and also inv and plan_inv of plan_irfft, which should all give
94-
# functionally identical plans
9586
for P in [plan_rfft(x, dims), inv(plan_irfft(ry, size(x, dims), dims)),
9687
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))]
9788
@test eltype(P) <: Real
@@ -100,17 +91,17 @@ function test_fft_backend()
10091
@test fftdims(P) == dims
10192
end
10293

94+
# BRFFT
10395
fftw_brfft = complex.(size(x, dims) .* x)
10496
@test AbstractFFTs.brfft(ry, size(x, dims), dims) fftw_brfft
10597
P = plan_brfft(ry, size(x, dims), dims)
10698
@test P * ry fftw_brfft
10799
@test P \ (P * ry) ry
108100
@test fftdims(P) == dims
109101

102+
# IRFFT
110103
fftw_irfft = complex.(x)
111104
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
112-
# test plan_rfft and also inv and plan_inv of plan_irfft, which should all give
113-
# functionally identical plans
114105
for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x, dims)),
115106
AbstractFFTs.plan_inv(plan_rfft(x, dims))]
116107
@test P * ry fftw_irfft

test/runtests.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,24 @@ using Test
33
using AbstractFFTs
44
using AbstractFFTs.TestUtils
55
using ChainRulesTestUtils
6+
import Unitful
67

78
Random.seed!(1234)
89

910
include("TestPlans.jl")
1011

1112
using .TestPlans
12-
test_fft_backend() # Tests for FFT plans (and operations in AbstractFFTs that derive from them)
13+
test_fft_backend(Array) # Tests for FFT plans (and operations in AbstractFFTs that derive from them)
14+
15+
@testset "rfft sizes" begin
16+
A = rand(11, 10)
17+
@test @inferred(AbstractFFTs.rfft_output_size(A, 1)) == (6, 10)
18+
@test @inferred(AbstractFFTs.rfft_output_size(A, 2)) == (11, 6)
19+
A1 = rand(6, 10); A2 = rand(11, 6)
20+
@test @inferred(AbstractFFTs.brfft_output_size(A1, 11, 1)) == (11, 10)
21+
@test @inferred(AbstractFFTs.brfft_output_size(A2, 10, 2)) == (11, 10)
22+
@test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2)
23+
end
1324

1425
@testset "Shift functions" begin
1526
@test @inferred(AbstractFFTs.fftshift([1 2 3])) == [3 1 2]
@@ -99,7 +110,7 @@ end
99110
# normalization should be inferable even if region is only inferred as ::Any,
100111
# need to wrap in another function to test this (note that p.region::Any for
101112
# p::TestPlan)
102-
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
113+
f9(p::AbstractFFTs.Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
103114
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
104115
end
105116

0 commit comments

Comments
 (0)