Skip to content

Commit 3b0225d

Browse files
committed
Add option to test with FFTW backend
1 parent 03ef58b commit 3b0225d

File tree

5 files changed

+483
-459
lines changed

5 files changed

+483
-459
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ julia = "^1.0"
1212

1313
[extras]
1414
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
15+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1819

1920
[targets]
20-
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]
21+
test = ["ChainRulesTestUtils", "Random", "FFTW", "Test", "Unitful"]

test/TestPlans.jl

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
module TestPlans
2+
3+
using AbstractFFTs
4+
using AbstractFFTs: Plan
5+
6+
mutable struct TestPlan{T,N} <: Plan{T}
7+
region
8+
sz::NTuple{N,Int}
9+
pinv::Plan{T}
10+
function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
11+
return new{T,N}(region, sz)
12+
end
13+
end
14+
15+
mutable struct InverseTestPlan{T,N} <: Plan{T}
16+
region
17+
sz::NTuple{N,Int}
18+
pinv::Plan{T}
19+
function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
20+
return new{T,N}(region, sz)
21+
end
22+
end
23+
24+
Base.size(p::TestPlan) = p.sz
25+
Base.ndims(::TestPlan{T,N}) where {T,N} = N
26+
Base.size(p::InverseTestPlan) = p.sz
27+
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
28+
29+
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
30+
return TestPlan{T}(region, size(x))
31+
end
32+
function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T}
33+
return InverseTestPlan{T}(region, size(x))
34+
end
35+
36+
function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T}
37+
unscaled_pinv = InverseTestPlan{T}(p.region, p.sz)
38+
N = AbstractFFTs.normalization(T, p.sz, p.region)
39+
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N)
40+
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N)
41+
return pinv
42+
end
43+
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T}
44+
unscaled_p = TestPlan{T}(pinv.region, pinv.sz)
45+
N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
46+
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N)
47+
p = AbstractFFTs.ScaledPlan(unscaled_p, N)
48+
return p
49+
end
50+
51+
# Just a helper function since forward and backward are nearly identical
52+
# The function does not check if the size of `y` and `x` are compatible, this
53+
# is done in the function where `dft!` is called since the check differs for FFTs
54+
# with complex and real-valued signals
55+
function dft!(
56+
y::AbstractArray{<:Complex,N},
57+
x::AbstractArray{<:Union{Complex,Real},N},
58+
dims,
59+
sign::Int
60+
) where {N}
61+
# check that dimensions that are transformed are unique
62+
allunique(dims) || error("dimensions have to be unique")
63+
64+
T = eltype(y)
65+
# we use `size(x, d)` since for real-valued signals
66+
# `size(y, first(dims)) = size(x, first(dims)) ÷ 2 + 1`
67+
cs = map(d -> T(sign * 2π / size(x, d)), dims)
68+
fill!(y, zero(T))
69+
for yidx in CartesianIndices(y)
70+
# set of indices of `x` on which `y[yidx]` depends
71+
xindices = CartesianIndices(
72+
ntuple(i -> i in dims ? axes(x, i) : yidx[i]:yidx[i], Val(N))
73+
)
74+
for xidx in xindices
75+
y[yidx] += x[xidx] * cis(sum(c * (yidx[d] - 1) * (xidx[d] - 1) for (c, d) in zip(cs, dims)))
76+
end
77+
end
78+
return y
79+
end
80+
81+
function mul!(
82+
y::AbstractArray{<:Complex,N}, p::TestPlan, x::AbstractArray{<:Union{Complex,Real},N}
83+
) where {N}
84+
size(y) == size(p) == size(x) || throw(DimensionMismatch())
85+
dft!(y, x, p.region, -1)
86+
end
87+
function mul!(
88+
y::AbstractArray{<:Complex,N}, p::InverseTestPlan, x::AbstractArray{<:Union{Complex,Real},N}
89+
) where {N}
90+
size(y) == size(p) == size(x) || throw(DimensionMismatch())
91+
dft!(y, x, p.region, 1)
92+
end
93+
94+
Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
95+
Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
96+
97+
mutable struct TestRPlan{T,N} <: Plan{T}
98+
region
99+
sz::NTuple{N,Int}
100+
pinv::Plan{T}
101+
TestRPlan{T}(region, sz::NTuple{N,Int}) where {T,N} = new{T,N}(region, sz)
102+
end
103+
104+
mutable struct InverseTestRPlan{T,N} <: Plan{T}
105+
d::Int
106+
region
107+
sz::NTuple{N,Int}
108+
pinv::Plan{T}
109+
function InverseTestRPlan{T}(d::Int, region, sz::NTuple{N,Int}) where {T,N}
110+
sz[first(region)::Int] == d ÷ 2 + 1 || error("incompatible dimensions")
111+
return new{T,N}(d, region, sz)
112+
end
113+
end
114+
115+
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T}
116+
return TestRPlan{T}(region, size(x))
117+
end
118+
function AbstractFFTs.plan_brfft(x::AbstractArray{T}, d, region; kwargs...) where {T}
119+
return InverseTestRPlan{T}(d, region, size(x))
120+
end
121+
function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
122+
firstdim = first(p.region)::Int
123+
d = p.sz[firstdim]
124+
sz = ntuple(i -> i == firstdim ? d ÷ 2 + 1 : p.sz[i], Val(N))
125+
_N = AbstractFFTs.normalization(T, p.sz, p.region)
126+
127+
unscaled_pinv = InverseTestRPlan{T}(d, p.region, sz)
128+
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N)
129+
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N)
130+
return pinv
131+
end
132+
133+
function AbstractFFTs.plan_inv(pinv::InverseTestRPlan{T,N}) where {T,N}
134+
firstdim = first(pinv.region)::Int
135+
sz = ntuple(i -> i == firstdim ? pinv.d : pinv.sz[i], Val(N))
136+
_N = AbstractFFTs.normalization(T, sz, pinv.region)
137+
138+
unscaled_p = TestRPlan{T}(pinv.region, sz)
139+
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N)
140+
p = AbstractFFTs.ScaledPlan(unscaled_p, _N)
141+
return p
142+
end
143+
144+
Base.size(p::TestRPlan) = p.sz
145+
Base.ndims(::TestRPlan{T,N}) where {T,N} = N
146+
Base.size(p::InverseTestRPlan) = p.sz
147+
Base.ndims(::InverseTestRPlan{T,N}) where {T,N} = N
148+
149+
function real_invdft!(
150+
y::AbstractArray{<:Real,N},
151+
x::AbstractArray{<:Union{Complex,Real},N},
152+
dims,
153+
) where {N}
154+
# check that dimensions that are transformed are unique
155+
allunique(dims) || error("dimensions have to be unique")
156+
157+
firstdim = first(dims)
158+
size_x_firstdim = size(x, firstdim)
159+
iseven_firstdim = iseven(size(y, firstdim))
160+
# we do not check that the input corresponds to a real-valued signal
161+
# (i.e., that the first and, if `iseven_firstdim`, the last value in dimension
162+
# `haldim` of `x` are real values) due to numerical inaccuracies
163+
# instead we just use the real part of these entries
164+
165+
T = eltype(y)
166+
# we use `size(y, d)` since `size(x, first(dims)) = size(y, first(dims)) ÷ 2 + 1`
167+
cs = map(d -> T(2π / size(y, d)), dims)
168+
fill!(y, zero(T))
169+
for yidx in CartesianIndices(y)
170+
# set of indices of `x` on which `y[yidx]` depends
171+
xindices = CartesianIndices(
172+
ntuple(i -> i in dims ? axes(x, i) : yidx[i]:yidx[i], Val(N))
173+
)
174+
for xidx in xindices
175+
coeffimag, coeffreal = sincos(
176+
sum(c * (yidx[d] - 1) * (xidx[d] - 1) for (c, d) in zip(cs, dims))
177+
)
178+
179+
# the first and, if `iseven_firstdim`, the last term of the DFT are scaled
180+
# with 1 instead of 2 and only the real part is used (see note above)
181+
xidx_firstdim = xidx[firstdim]
182+
if xidx_firstdim == 1 || (iseven_firstdim && xidx_firstdim == size_x_firstdim)
183+
y[yidx] += coeffreal * real(x[xidx])
184+
else
185+
xreal, ximag = reim(x[xidx])
186+
y[yidx] += 2 * (coeffreal * xreal - coeffimag * ximag)
187+
end
188+
end
189+
end
190+
191+
return y
192+
end
193+
194+
to_real!(x::AbstractArray) = map!(real, x, x)
195+
196+
function Base.:*(p::TestRPlan, x::AbstractArray)
197+
size(p) == size(x) || error("array and plan are not consistent")
198+
199+
# create output array
200+
firstdim = first(p.region)::Int
201+
d = size(x, firstdim)
202+
firstdim_size = d ÷ 2 + 1
203+
T = complex(float(eltype(x)))
204+
sz = ntuple(i -> i == firstdim ? firstdim_size : size(x, i), Val(ndims(x)))
205+
y = similar(x, T, sz)
206+
207+
# compute DFT
208+
dft!(y, x, p.region, -1)
209+
210+
# we clean the output a bit to make sure that we return real values
211+
# whenever the output is mathematically guaranteed to be a real number
212+
to_real!(selectdim(y, firstdim, 1))
213+
if iseven(d)
214+
to_real!(selectdim(y, firstdim, firstdim_size))
215+
end
216+
217+
return y
218+
end
219+
220+
function Base.:*(p::InverseTestRPlan, x::AbstractArray)
221+
size(p) == size(x) || error("array and plan are not consistent")
222+
223+
# create output array
224+
firstdim = first(p.region)::Int
225+
d = p.d
226+
sz = ntuple(i -> i == firstdim ? d : size(x, i), Val(ndims(x)))
227+
y = similar(x, real(float(eltype(x))), sz)
228+
229+
# compute DFT
230+
real_invdft!(y, x, p.region)
231+
232+
return y
233+
end
234+
235+
end

0 commit comments

Comments
 (0)