Skip to content

Commit 1c69f6c

Browse files
committed
Add TestUtils submodule/extension
1 parent 1cc9ca0 commit 1c69f6c

File tree

6 files changed

+309
-146
lines changed

6 files changed

+309
-146
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ version = "1.4.0"
55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
89

910
[weakdeps]
1011
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
12+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1113

1214
[extensions]
1315
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
16+
AbstractFFTsTestUtilsExt = "Test"
1417

1518
[compat]
1619
ChainRulesCore = "1"

ext/AbstractFFTsTestUtilsExt.jl

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
2+
3+
module AbstractFFTsTestUtilsExt
4+
5+
using AbstractFFTs
6+
using AbstractFFTs: TestUtils
7+
using AbstractFFTs.LinearAlgebra
8+
using Test
9+
10+
# Ground truth _x_fft computed using FFTW library
11+
const TEST_CASES = (
12+
(; x = collect(1:7), dims = 1,
13+
x_fft = [28.0 + 0.0im,
14+
-3.5 + 7.267824888003178im,
15+
-3.5 + 2.7911568610884143im,
16+
-3.5 + 0.7988521603655248im,
17+
-3.5 - 0.7988521603655248im,
18+
-3.5 - 2.7911568610884143im,
19+
-3.5 - 7.267824888003178im]),
20+
(; x = collect(1:8), dims = 1,
21+
x_fft = [36.0 + 0.0im,
22+
-4.0 + 9.65685424949238im,
23+
-4.0 + 4.0im,
24+
-4.0 + 1.6568542494923806im,
25+
-4.0 + 0.0im,
26+
-4.0 - 1.6568542494923806im,
27+
-4.0 - 4.0im,
28+
-4.0 - 9.65685424949238im]),
29+
(; x = collect(reshape(1:8, 2, 4)), dims = 2,
30+
x_fft = [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
31+
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
32+
(; x = collect(reshape(1:9, 3, 3)), dims = 2,
33+
x_fft = [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
34+
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
35+
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
36+
(; x = collect(reshape(1:8, 2, 2, 2)), dims = 1:2,
37+
x_fft = cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
38+
[26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
39+
dims=3)),
40+
(; x = collect(1:7) + im * collect(8:14), dims = 1,
41+
x_fft = [28.0 + 77.0im,
42+
-10.76782488800318 + 3.767824888003175im,
43+
-6.291156861088416 - 0.7088431389115883im,
44+
-4.298852160365525 - 2.7011478396344746im,
45+
-2.7011478396344764 - 4.298852160365524im,
46+
-0.7088431389115866 - 6.291156861088417im,
47+
3.767824888003177 - 10.76782488800318im]),
48+
(; x = collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), dims = 1:2,
49+
x_fft = cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
50+
[26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
51+
dims=3)),
52+
)
53+
54+
"""
55+
TestUtils.test_complex_fft(ArrayType=Array; test_real=true, test_inplace=true)
56+
57+
Run tests to verify correctness of FFT/BFFT/IFFT functionality using a particular backend plan implementation.
58+
The backend implementation is assumed to be loaded prior to calling this function.
59+
60+
# Arguments
61+
62+
- `ArrayType`: determines the `AbstractArray` implementation for
63+
which the correctness tests are run. Arrays are constructed via
64+
`convert(ArrayType, ...)`.
65+
- `test_inplace=true`: whether to test in-place plans.
66+
"""
67+
function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true)
68+
@testset "correctness of fft, bfft, ifft" begin
69+
for test_case in TEST_CASES
70+
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
71+
x = convert(ArrayType, _x) # dummy array that will be passed to plans
72+
x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs
73+
x_fft = convert(ArrayType, _x_fft)
74+
75+
# FFT
76+
@test fft(x, dims) x_fft
77+
if test_inplace
78+
_x_complexf = copy(x_complexf)
79+
@test fft!(_x_complexf, dims) x_fft
80+
@test _x_complexf x_fft
81+
end
82+
# test OOP plans, checking plan_fft and also inv of plan_ifft,
83+
# which should give functionally identical plans
84+
for P in (plan_fft(similar(x_complexf), dims), inv(plan_ifft(similar(x_complexf), dims)))
85+
@test eltype(P) <: Complex
86+
@test fftdims(P) == dims
87+
@test P * x x_fft
88+
@test P \ (P * x) x
89+
_x_out = similar(x_fft)
90+
@test mul!(_x_out, P, x_complexf) x_fft
91+
@test _x_out x_fft
92+
end
93+
if test_inplace
94+
# test IIP plans
95+
for P in (plan_fft!(similar(x_complexf), dims), inv(plan_ifft!(similar(x_complexf), dims)))
96+
@test eltype(P) <: Complex
97+
@test fftdims(P) == dims
98+
_x_complexf = copy(x_complexf)
99+
@test P * _x_complexf x_fft
100+
@test _x_complexf x_fft
101+
@test P \ _x_complexf x
102+
@test _x_complexf x
103+
end
104+
end
105+
106+
# BFFT
107+
x_scaled = prod(size(x, d) for d in dims) .* x
108+
@test bfft(x_fft, dims) x_scaled
109+
if test_inplace
110+
_x_fft = copy(x_fft)
111+
@test bfft!(_x_fft, dims) x_scaled
112+
@test _x_fft x_scaled
113+
end
114+
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
115+
for P in (plan_bfft(similar(x_fft), dims),)
116+
@test eltype(P) <: Complex
117+
@test fftdims(P) == dims
118+
@test P * x_fft x_scaled
119+
@test P \ (P * x_fft) x_fft
120+
_x_complexf = similar(x_complexf)
121+
@test mul!(_x_complexf, P, x_fft) x_scaled
122+
@test _x_complexf x_scaled
123+
end
124+
# test IIP plans
125+
for P in (plan_bfft!(similar(x_fft), dims),)
126+
@test eltype(P) <: Complex
127+
@test fftdims(P) == dims
128+
_x_fft = copy(x_fft)
129+
@test P * _x_fft x_scaled
130+
@test _x_fft x_scaled
131+
@test P \ _x_fft x_fft
132+
@test _x_fft x_fft
133+
end
134+
135+
# IFFT
136+
@test ifft(x_fft, dims) x
137+
if test_inplace
138+
_x_fft = copy(x_fft)
139+
@test ifft!(_x_fft, dims) x
140+
@test _x_fft x
141+
end
142+
# test OOP plans
143+
for P in (plan_ifft(similar(x_complexf), dims), inv(plan_fft(similar(x_complexf), dims)))
144+
@test eltype(P) <: Complex
145+
@test fftdims(P) == dims
146+
@test P * x_fft x
147+
@test P \ (P * x_fft) x_fft
148+
_x_complexf = similar(x_complexf)
149+
@test mul!(_x_complexf, P, x_fft) x
150+
@test _x_complexf x
151+
end
152+
# test IIP plans
153+
if test_inplace
154+
for P in (plan_ifft!(similar(x_complexf), dims), inv(plan_fft!(similar(x_complexf), dims)))
155+
@test eltype(P) <: Complex
156+
@test fftdims(P) == dims
157+
_x_fft = copy(x_fft)
158+
@test P * _x_fft x
159+
@test _x_fft x
160+
@test P \ _x_fft x_fft
161+
@test _x_fft x_fft
162+
end
163+
end
164+
end
165+
end
166+
end
167+
168+
"""
169+
TestUtils.test_real_fft(ArrayType=Array; test_real=true, test_inplace=true)
170+
171+
Run tests to verify correctness of RFFT/BRFFT/IRFFT functionality using a particular backend plan implementation.
172+
The backend implementation is assumed to be loaded prior to calling this function.
173+
174+
# Arguments
175+
176+
- `ArrayType`: determines the `AbstractArray` implementation for
177+
which the correctness tests are run. Arrays are constructed via
178+
`convert(ArrayType, ...)`.
179+
- `test_inplace=true`: whether to test in-place plans.
180+
"""
181+
function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true)
182+
@testset "correctness of rfft, brfft, irfft" begin
183+
for test_case in TEST_CASES[5:5]
184+
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
185+
x = convert(ArrayType, _x) # dummy array that will be passed to plans
186+
x_real = float.(x) # for testing mutating real FFTs
187+
x_fft = convert(ArrayType, _x_fft)
188+
x_rfft = selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1))
189+
190+
if !(eltype(x) <: Real)
191+
continue
192+
end
193+
194+
# RFFT
195+
@test rfft(x, dims) x_rfft
196+
for P in (plan_rfft(similar(x_real), dims), inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)))
197+
@test eltype(P) <: Real
198+
@test fftdims(P) == dims
199+
# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
200+
@test P * copy(x) x_rfft
201+
@test P \ (P * copy(x)) x
202+
_x_rfft = similar(x_rfft)
203+
@test mul!(_x_rfft, P, copy(x_real)) x_rfft
204+
@test _x_rfft x_rfft
205+
end
206+
207+
# BRFFT
208+
x_scaled = prod(size(x, d) for d in dims) .* x
209+
@test brfft(x_rfft, size(x, first(dims)), dims) x_scaled
210+
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
211+
@test eltype(P) <: Complex
212+
@test fftdims(P) == dims
213+
@test P * copy(x_rfft) x_scaled
214+
@test P \ (P * copy(x_rfft)) x_rfft
215+
_x_scaled = similar(x_real)
216+
@test mul!(_x_scaled, P, copy(x_rfft)) x_scaled
217+
@test _x_scaled x_scaled
218+
end
219+
220+
# IRFFT
221+
@test irfft(x_rfft, size(x, first(dims)), dims) x
222+
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims), inv(plan_rfft(similar(x_real), dims)))
223+
@test eltype(P) <: Complex
224+
@test fftdims(P) == dims
225+
@test P * copy(x_rfft) x
226+
@test P \ (P * copy(x_rfft)) x_rfft
227+
_x_real = similar(x_real)
228+
@test mul!(_x_real, P, copy(x_rfft)) x_real
229+
end
230+
end
231+
end
232+
end
233+
234+
end

src/AbstractFFTs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
66
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
77

88
include("definitions.jl")
9+
include("TestUtils.jl")
910

1011
if !isdefined(Base, :get_extension)
1112
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
13+
include("../ext/AbstractFFTsTestUtilsExt.jl")
1214
end
1315

1416
end # module

src/TestUtils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module TestUtils
2+
3+
function test_complex_fft end
4+
function test_real_fft end
5+
6+
function __init__()
7+
# Better error message if users forget to load Test
8+
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
9+
if exc.f in (test_real_fft, test_complex_fft)
10+
print(io, "\nDid you forget to load Test?")
11+
end
12+
end
13+
end
14+
15+
end

0 commit comments

Comments
 (0)