Skip to content

Commit 6c81dfd

Browse files
committed
Test plan adjoints and AD rules
1 parent b3c2a09 commit 6c81dfd

File tree

2 files changed

+114
-16
lines changed

2 files changed

+114
-16
lines changed

test/runtests.jl

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using AbstractFFTs
44
using AbstractFFTs: Plan
55
using ChainRulesTestUtils
6+
using ChainRulesCore: NoTangent
67

78
using LinearAlgebra
89
using Random
@@ -197,6 +198,79 @@ end
197198
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
198199
end
199200

201+
@testset "output size" begin
202+
@testset "complex fft output size" begin
203+
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
204+
N = ndims(x)
205+
y = randn(size(x))
206+
for dims in unique((1, 1:N, N))
207+
P = plan_fft(x, dims)
208+
@test AbstractFFTs.output_size(P) == size(x)
209+
@test AbstractFFTs.output_size(P') == size(x)
210+
Pinv = plan_ifft(x)
211+
@test AbstractFFTs.output_size(Pinv) == size(x)
212+
@test AbstractFFTs.output_size(Pinv') == size(x)
213+
end
214+
end
215+
end
216+
@testset "real fft output size" begin
217+
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
218+
N = ndims(x)
219+
for dims in unique((1, 1:N, N))
220+
P = plan_rfft(x, dims)
221+
Px_sz = size(P * x)
222+
@test AbstractFFTs.output_size(P) == Px_sz
223+
@test AbstractFFTs.output_size(P') == size(x)
224+
y = randn(Px_sz) .+ randn(Px_sz) * im
225+
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
226+
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
227+
@test AbstractFFTs.output_size(Pinv') == size(y)
228+
end
229+
end
230+
end
231+
end
232+
233+
@testset "adjoint" begin
234+
@testset "complex fft adjoint" begin
235+
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
236+
N = ndims(x)
237+
y = randn(size(x))
238+
for dims in unique((1, 1:N, N))
239+
P = plan_fft(x, dims)
240+
@test (P')' * x == P * x # test adjoint of adjoint
241+
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
242+
@test dot(y, P * x) dot(P' * y, x) # test validity of adjoint
243+
@test_broken dot(y, P \ x) dot(P' \ y, x)
244+
Pinv = plan_ifft(y)
245+
@test (Pinv')' * y == Pinv * y
246+
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
247+
@test dot(x, Pinv * y) dot(Pinv' * x, y)
248+
@test_broken dot(x, Pinv \ y) dot(Pinv' \ x, y)
249+
end
250+
end
251+
end
252+
@testset "real fft adjoint" begin
253+
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
254+
N = ndims(x)
255+
for dims in unique((1, 1:N, N))
256+
P = plan_rfft(x, dims)
257+
y_real = randn(size(P * x))
258+
y_imag = randn(size(P * x))
259+
y = y_real .+ y_imag .* im
260+
@test (P')' * x == P * x
261+
@test size(P') == AbstractFFTs.output_size(P)
262+
@test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) dot(P' * y, x)
263+
@test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) dot(P' * y, x)
264+
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
265+
@test (Pinv')' * y == Pinv * y
266+
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
267+
@test dot(x, Pinv * y) dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x))
268+
@test_broken dot(x, Pinv \ y) dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x))
269+
end
270+
end
271+
end
272+
end
273+
200274
@testset "ChainRules" begin
201275
@testset "shift functions" begin
202276
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
@@ -218,20 +292,31 @@ end
218292
end
219293

220294
@testset "fft" begin
221-
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
295+
for x in (randn(2), randn(2, 3), randn(3, 4, 5))
222296
N = ndims(x)
223297
complex_x = complex.(x)
224298
for dims in unique((1, 1:N, N))
299+
# fft, ifft, bfft
225300
for f in (fft, ifft, bfft)
226301
test_frule(f, x, dims)
227302
test_rrule(f, x, dims)
228303
test_frule(f, complex_x, dims)
229304
test_rrule(f, complex_x, dims)
230305
end
306+
for pf in (plan_fft, plan_ifft, plan_bfft)
307+
test_frule(*, pf(x, dims) NoTangent(), x)
308+
test_rrule(*, pf(x, dims) NoTangent(), x)
309+
test_frule(*, pf(complex_x, dims) NoTangent(), complex_x)
310+
test_rrule(*, pf(complex_x, dims) NoTangent(), complex_x)
311+
end
231312

313+
# rfft
232314
test_frule(rfft, x, dims)
233315
test_rrule(rfft, x, dims)
316+
test_frule(*, plan_rfft(x, dims) NoTangent(), x)
317+
test_rrule(*, plan_rfft(x, dims) NoTangent(), x)
234318

319+
# irfft, brfft
235320
for f in (irfft, brfft)
236321
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
237322
test_frule(f, x, d, dims)
@@ -240,6 +325,12 @@ end
240325
test_rrule(f, complex_x, d, dims)
241326
end
242327
end
328+
for pf in (plan_irfft, plan_brfft)
329+
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
330+
test_frule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
331+
test_rrule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
332+
end
333+
end
243334
end
244335
end
245336
end

test/testplans.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
mutable struct TestPlan{T,N} <: Plan{T}
2-
region
1+
mutable struct TestPlan{T,N,G} <: Plan{T}
2+
region::G
33
sz::NTuple{N,Int}
44
pinv::Plan{T}
5-
function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
6-
return new{T,N}(region, sz)
5+
function TestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G}
6+
return new{T,N,G}(region, sz)
77
end
88
end
99

10-
mutable struct InverseTestPlan{T,N} <: Plan{T}
11-
region
10+
mutable struct InverseTestPlan{T,N,G} <: Plan{T}
11+
region::G
1212
sz::NTuple{N,Int}
1313
pinv::Plan{T}
14-
function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
15-
return new{T,N}(region, sz)
14+
function InverseTestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G}
15+
return new{T,N,G}(region, sz)
1616
end
1717
end
1818

@@ -21,6 +21,9 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
2121
Base.size(p::InverseTestPlan) = p.sz
2222
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
2323

24+
AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
25+
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
26+
2427
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
2528
return TestPlan{T}(region, size(x))
2629
end
@@ -90,24 +93,28 @@ end
9093
Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
9194
Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
9295

93-
mutable struct TestRPlan{T,N} <: Plan{T}
94-
region
96+
mutable struct TestRPlan{T,N,G} <: Plan{T}
97+
region::G
9598
sz::NTuple{N,Int}
9699
pinv::Plan{T}
97-
TestRPlan{T}(region, sz::NTuple{N,Int}) where {T,N} = new{T,N}(region, sz)
100+
TestRPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} = new{T,N,G}(region, sz)
98101
end
99102

100-
mutable struct InverseTestRPlan{T,N} <: Plan{T}
103+
mutable struct InverseTestRPlan{T,N,G} <: Plan{T}
101104
d::Int
102-
region
105+
region::G
103106
sz::NTuple{N,Int}
104107
pinv::Plan{T}
105-
function InverseTestRPlan{T}(d::Int, region, sz::NTuple{N,Int}) where {T,N}
108+
function InverseTestRPlan{T}(d::Int, region::G, sz::NTuple{N,Int}) where {T,N,G}
106109
sz[first(region)::Int] == d ÷ 2 + 1 || error("incompatible dimensions")
107-
return new{T,N}(d, region, sz)
110+
return new{T,N,G}(d, region, sz)
108111
end
109112
end
110113

114+
AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
115+
AbstractFFTs.ProjectionStyle(::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle()
116+
AbstractFFTs.irfft_dim(p::InverseTestRPlan) = p.d
117+
111118
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T}
112119
return TestRPlan{T}(region, size(x))
113120
end

0 commit comments

Comments
 (0)