Skip to content

Commit 91102a1

Browse files
gaurav-aryaGaurav Arya
authored andcommitted
Test plan adjoints and AD rules
1 parent 33e365e commit 91102a1

File tree

2 files changed

+71
-18
lines changed

2 files changed

+71
-18
lines changed

test/runtests.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using AbstractFFTs
44
using AbstractFFTs: Plan
55
using ChainRulesTestUtils
6+
using ChainRulesCore: NoTangent
67

78
using LinearAlgebra
89
using Random
@@ -197,6 +198,39 @@ end
197198
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
198199
end
199200

201+
@testset "adjoint" begin
202+
@testset "complex fft adjoint" begin
203+
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
204+
N = ndims(x)
205+
y = randn(size(x))
206+
for dims in unique((1, 1:N, N))
207+
P = plan_fft(x, dims)
208+
@test dot(y, P * x) dot(P' * y, x)
209+
@test_broken dot(y, P \ x) dot(P' \ y, x)
210+
Pinv = plan_ifft(x)
211+
@test dot(x, Pinv * y) dot(Pinv' * x, y)
212+
@test_broken dot(x, Pinv \ y) dot(Pinv' \ x, y)
213+
end
214+
end
215+
end
216+
@testset "real fft adjoint" begin
217+
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
218+
N = ndims(x)
219+
for dims in unique((1, 1:N, N))
220+
P = plan_rfft(similar(x), dims)
221+
y_real = randn(size(P * x))
222+
y_imag = randn(size(P * x))
223+
y = y_real .+ y_imag .* im
224+
@test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) dot(P' * y, x)
225+
@test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) dot(P' * y, x)
226+
Pinv = plan_irfft(similar(y), size(x)[first(dims)], dims)
227+
@test dot(x, Pinv * y) dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x))
228+
@test_broken dot(x, Pinv \ y) dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x))
229+
end
230+
end
231+
end
232+
end
233+
200234
@testset "ChainRules" begin
201235
@testset "shift functions" begin
202236
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
@@ -218,7 +252,7 @@ end
218252
end
219253

220254
@testset "fft" begin
221-
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
255+
for x in (randn(2), randn(2, 3), randn(3, 4, 5))
222256
N = ndims(x)
223257
complex_x = complex.(x)
224258
for dims in unique((1, 1:N, N))
@@ -229,8 +263,12 @@ end
229263
test_rrule(f, complex_x, dims)
230264
end
231265

232-
test_frule(rfft, x, dims)
233-
test_rrule(rfft, x, dims)
266+
for pf in (plan_fft, plan_ifft, plan_bfft)
267+
test_frule(*, pf(x, dims) NoTangent(), x)
268+
test_rrule(*, pf(x, dims) NoTangent(), x)
269+
test_frule(*, pf(complex_x, dims) NoTangent(), complex_x)
270+
test_rrule(*, pf(complex_x, dims) NoTangent(), complex_x)
271+
end
234272

235273
for f in (irfft, brfft)
236274
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
@@ -240,6 +278,14 @@ end
240278
test_rrule(f, complex_x, d, dims)
241279
end
242280
end
281+
282+
for pf in (plan_irfft, plan_brfft)
283+
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
284+
test_frule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
285+
test_rrule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
286+
end
287+
end
288+
243289
end
244290
end
245291
end

test/testplans.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
mutable struct TestPlan{T,N} <: Plan{T}
2-
region
1+
mutable struct TestPlan{T,N,G} <: Plan{T}
2+
region::G
33
sz::NTuple{N,Int}
44
pinv::Plan{T}
5-
function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
6-
return new{T,N}(region, sz)
5+
function TestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G}
6+
return new{T,N,G}(region, sz)
77
end
88
end
99

10-
mutable struct InverseTestPlan{T,N} <: Plan{T}
11-
region
10+
mutable struct InverseTestPlan{T,N,G} <: Plan{T}
11+
region::G
1212
sz::NTuple{N,Int}
1313
pinv::Plan{T}
14-
function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
15-
return new{T,N}(region, sz)
14+
function InverseTestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G}
15+
return new{T,N,G}(region, sz)
1616
end
1717
end
1818

@@ -21,6 +21,9 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
2121
Base.size(p::InverseTestPlan) = p.sz
2222
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
2323

24+
AbstractFFTs.projection_style(::TestPlan) = :none
25+
AbstractFFTs.projection_style(::InverseTestPlan) = :none
26+
2427
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
2528
return TestPlan{T}(region, size(x))
2629
end
@@ -90,24 +93,28 @@ end
9093
Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
9194
Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
9295

93-
mutable struct TestRPlan{T,N} <: Plan{T}
94-
region
96+
mutable struct TestRPlan{T,N,G} <: Plan{T}
97+
region::G
9598
sz::NTuple{N,Int}
9699
pinv::Plan{T}
97-
TestRPlan{T}(region, sz::NTuple{N,Int}) where {T,N} = new{T,N}(region, sz)
100+
TestRPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} = new{T,N,G}(region, sz)
98101
end
99102

100-
mutable struct InverseTestRPlan{T,N} <: Plan{T}
103+
mutable struct InverseTestRPlan{T,N,G} <: Plan{T}
101104
d::Int
102-
region
105+
region::G
103106
sz::NTuple{N,Int}
104107
pinv::Plan{T}
105-
function InverseTestRPlan{T}(d::Int, region, sz::NTuple{N,Int}) where {T,N}
108+
function InverseTestRPlan{T}(d::Int, region::G, sz::NTuple{N,Int}) where {T,N,G}
106109
sz[first(region)::Int] == d ÷ 2 + 1 || error("incompatible dimensions")
107-
return new{T,N}(d, region, sz)
110+
return new{T,N,G}(d, region, sz)
108111
end
109112
end
110113

114+
AbstractFFTs.projection_style(::TestRPlan) = :real
115+
AbstractFFTs.projection_style(::InverseTestRPlan) = :real_inv
116+
AbstractFFTs.irfft_dim(p::InverseTestRPlan) = p.d
117+
111118
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T}
112119
return TestRPlan{T}(region, size(x))
113120
end

0 commit comments

Comments
 (0)