Skip to content

Commit 09d8383

Browse files
committed
Add adjoint testing to test utilities
1 parent babfca6 commit 09d8383

File tree

3 files changed

+38
-52
lines changed

3 files changed

+38
-52
lines changed

ext/AbstractFFTsTestUtilsExt.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,26 @@ const TEST_CASES = (
5151
dims=3)),
5252
)
5353

54+
# Perform generic adjoint plan tests
55+
function _adjoint_test(P, x; real_plan=false)
56+
y = rand(eltype(P * x), size(P * x))
57+
# test basic properties
58+
@test_broken eltype(P') === typeof(y) # (AbstactFFTs.jl#110)
59+
@test fftdims(P') == fftdims(P)
60+
@test (P')' === P # test adjoint of adjoint
61+
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
62+
# test correctness of adjoint and its inverse via the dot test
63+
if !real_plan
64+
@test dot(y, P * x) dot(P' * y, x)
65+
@test dot(y, P \ x) dot(P' \ y, x)
66+
else
67+
_component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y))
68+
@test _component_dot(y, P * copy(x)) _component_dot(P' * copy(y), x)
69+
@test _component_dot(x, P \ copy(y)) _component_dot(P' \ copy(x), y)
70+
end
71+
@test_throws MethodError mul!(x, P', y)
72+
end
73+
5474
"""
5575
TestUtils.test_complex_fft(ArrayType=Array; test_real=true, test_inplace=true)
5676
@@ -63,8 +83,9 @@ The backend implementation is assumed to be loaded prior to calling this functio
6383
which the correctness tests are run. Arrays are constructed via
6484
`convert(ArrayType, ...)`.
6585
- `test_inplace=true`: whether to test in-place plans.
86+
- `test_adjoint=true`: whether to test adjoints of plans.
6687
"""
67-
function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true)
88+
function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adjoint=true)
6889
@testset "correctness of fft, bfft, ifft" begin
6990
for test_case in TEST_CASES
7091
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
@@ -89,6 +110,9 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true)
89110
_x_out = similar(x_fft)
90111
@test mul!(_x_out, P, x_complexf) x_fft
91112
@test _x_out x_fft
113+
if test_adjoint
114+
_adjoint_test(P, x_complexf)
115+
end
92116
end
93117
if test_inplace
94118
# test IIP plans
@@ -120,6 +144,9 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true)
120144
_x_complexf = similar(x_complexf)
121145
@test mul!(_x_complexf, P, x_fft) x_scaled
122146
@test _x_complexf x_scaled
147+
if test_adjoint
148+
_adjoint_test(P, x_complexf)
149+
end
123150
end
124151
# test IIP plans
125152
for P in (plan_bfft!(similar(x_fft), dims),)
@@ -148,6 +175,9 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true)
148175
_x_complexf = similar(x_complexf)
149176
@test mul!(_x_complexf, P, x_fft) x
150177
@test _x_complexf x
178+
if test_adjoint
179+
_adjoint_test(P, x_complexf)
180+
end
151181
end
152182
# test IIP plans
153183
if test_inplace
@@ -177,10 +207,11 @@ The backend implementation is assumed to be loaded prior to calling this functio
177207
which the correctness tests are run. Arrays are constructed via
178208
`convert(ArrayType, ...)`.
179209
- `test_inplace=true`: whether to test in-place plans.
210+
- `test_adjoint=true`: whether to test adjoints of plans.
180211
"""
181-
function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true)
212+
function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoint=true)
182213
@testset "correctness of rfft, brfft, irfft" begin
183-
for test_case in TEST_CASES[5:5]
214+
for test_case in TEST_CASES
184215
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
185216
x = convert(ArrayType, _x) # dummy array that will be passed to plans
186217
x_real = float.(x) # for testing mutating real FFTs
@@ -202,6 +233,9 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true)
202233
_x_rfft = similar(x_rfft)
203234
@test mul!(_x_rfft, P, copy(x_real)) x_rfft
204235
@test _x_rfft x_rfft
236+
if test_adjoint
237+
_adjoint_test(P, x_real; real_plan=true)
238+
end
205239
end
206240

207241
# BRFFT

src/definitions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
619619

620620
size(p::AdjointPlan) = output_size(p.p)
621621
output_size(p::AdjointPlan) = size(p.p)
622+
fftdims(p::AdjointPlan) = fftdims(p.p)
622623

623624
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
624625

test/runtests.jl

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -153,55 +153,6 @@ end
153153
end
154154
end
155155

156-
@testset "adjoint" begin
157-
@testset "complex fft adjoint" begin
158-
for x_shape in ((3,), (3, 4), (3, 4, 5))
159-
N = length(x_shape)
160-
real_x = randn(x_shape)
161-
complex_x = randn(ComplexF64, x_shape)
162-
y = randn(ComplexF64, x_shape)
163-
for x in (real_x, complex_x)
164-
for dims in unique((1, 1:N, N))
165-
P = plan_fft(x, dims)
166-
@test (P')' === P # test adjoint of adjoint
167-
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
168-
@test dot(y, P * x) dot(P' * y, x) # test validity of adjoint
169-
@test dot(y, P \ x) dot(P' \ y, x) # test inv of adjoint
170-
@test dot(y, P \ x) dot(AbstractFFTs.plan_inv(P') * y, x) # test plan_inv of adjoint
171-
Pinv = plan_ifft(y)
172-
@test (Pinv')' * y == Pinv * y
173-
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
174-
@test dot(x, Pinv * y) dot(Pinv' * x, y)
175-
@test dot(x, Pinv \ y) dot(Pinv' \ x, y)
176-
@test dot(x, Pinv \ y) dot(AbstractFFTs.plan_inv(Pinv') * x, y)
177-
@test_throws MethodError mul!(x, P', y)
178-
end
179-
end
180-
end
181-
end
182-
@testset "real fft adjoint" begin
183-
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
184-
N = ndims(x)
185-
for dims in unique((1, 1:N, N))
186-
P = plan_rfft(x, dims)
187-
y = randn(ComplexF64, size(P * x))
188-
@test (P')' * x == P * x
189-
@test size(P') == AbstractFFTs.output_size(P)
190-
@test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) dot(P' * y, x)
191-
@test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) dot(P \ y, x)
192-
@test dot(real.(y), real.(AbstractFFTs.plan_inv(P') * x)) +
193-
dot(imag.(y), imag.(AbstractFFTs.plan_inv(P') * x)) dot(P \ y, x)
194-
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
195-
@test (Pinv')' * y == Pinv * y
196-
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
197-
@test dot(x, Pinv * y) dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x))
198-
@test dot(x, Pinv' \ y) dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x))
199-
@test dot(x, AbstractFFTs.plan_inv(Pinv') * y) dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x))
200-
end
201-
end
202-
end
203-
end
204-
205156
# Test that dims defaults to 1:ndims for fft-like functions
206157
@testset "Default dims" begin
207158
for x in (randn(3), randn(3, 4), randn(3, 4, 5))

0 commit comments

Comments
 (0)