Skip to content

Commit a9c93b1

Browse files
committed
Support mul! for real plans in tests
1 parent 5482609 commit a9c93b1

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

test/TestPlans.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TestPlans
22

33
import AbstractFFTs
4+
import LinearAlgebra.mul!
45
using AbstractFFTs: Plan
56

67
mutable struct TestPlan{T,N} <: Plan{T}
@@ -193,6 +194,29 @@ end
193194

194195
to_real!(x::AbstractArray) = map!(real, x, x)
195196

197+
function mul!(y::AbstractArray, p::TestRPlan, x::AbstractArray)
198+
# compute DFT
199+
dft!(y, x, p.region, -1)
200+
201+
# we clean the output a bit to make sure that we return real values
202+
# whenever the output is mathematically guaranteed to be a real number
203+
firstdim = first(p.region)::Int
204+
d = size(x, firstdim)
205+
firstdim_size = d ÷ 2 + 1
206+
to_real!(selectdim(y, firstdim, 1))
207+
if iseven(d)
208+
to_real!(selectdim(y, firstdim, firstdim_size))
209+
end
210+
211+
return y
212+
end
213+
214+
function mul!(y::AbstractArray, p::InverseTestRPlan, x::AbstractArray)
215+
# compute DFT
216+
real_invdft!(y, x, p.region)
217+
return y
218+
end
219+
196220
function Base.:*(p::TestRPlan, x::AbstractArray)
197221
size(p) == size(x) || error("array and plan are not consistent")
198222

@@ -205,16 +229,7 @@ function Base.:*(p::TestRPlan, x::AbstractArray)
205229
y = similar(x, T, sz)
206230

207231
# compute DFT
208-
dft!(y, x, p.region, -1)
209-
210-
# we clean the output a bit to make sure that we return real values
211-
# whenever the output is mathematically guaranteed to be a real number
212-
to_real!(selectdim(y, firstdim, 1))
213-
if iseven(d)
214-
to_real!(selectdim(y, firstdim, firstdim_size))
215-
end
216-
217-
return y
232+
mul!(y, p, x)
218233
end
219234

220235
function Base.:*(p::InverseTestRPlan, x::AbstractArray)
@@ -227,9 +242,7 @@ function Base.:*(p::InverseTestRPlan, x::AbstractArray)
227242
y = similar(x, real(float(eltype(x))), sz)
228243

229244
# compute DFT
230-
real_invdft!(y, x, p.region)
231-
232-
return y
245+
mul!(y, p, x)
233246
end
234247

235248
end

0 commit comments

Comments
 (0)