Skip to content

Commit 0ac21f0

Browse files
jishnubstevengj
andauthored
Improve type inference in plan_r2r (#253)
* Improve type inference in plan_r2r * skip complex test on MKL * _mapIntknd for region Tuples * simplify mapInt implementation * store value of K as a field * reduce unnecesary changes * update comment above constructor * avoid code duplication in fix_kinds * adjust whitespace * Fix type in comment Co-authored-by: Steven G. Johnson <stevenj@mit.edu> --------- Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
1 parent 7412b3e commit 0ac21f0

File tree

2 files changed

+78
-20
lines changed

2 files changed

+78
-20
lines changed

src/fft.jl

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ end
245245
# needed to determine whether it is applicable. We need to put
246246
# this into a type to support a finalizer on the fftw_plan.
247247
# K is FORWARD/BACKWARD for forward/backward or r2c/c2r plans, respectively.
248-
# For r2r plans, K is a tuple of the transform kinds along each dimension.
248+
# For r2r plans, the field kinds::K is a tuple/vector of the transform kinds along each dimension.
249249
abstract type FFTWPlan{T<:fftwNumber,K,inplace} <: Plan{T} end
250-
for P in (:cFFTWPlan, :rFFTWPlan, :r2rFFTWPlan) # complex, r2c/c2r, and r2r
250+
for P in (:cFFTWPlan, :rFFTWPlan) # complex, r2c/c2r
251251
@eval begin
252252
mutable struct $P{T<:fftwNumber,K,inplace,N,G} <: FFTWPlan{T,K,inplace}
253253
plan::PlanPtr
@@ -277,6 +277,34 @@ for P in (:cFFTWPlan, :rFFTWPlan, :r2rFFTWPlan) # complex, r2c/c2r, and r2r
277277
end
278278
end
279279

280+
mutable struct r2rFFTWPlan{T<:fftwNumber,K,inplace,N,G} <: FFTWPlan{T,K,inplace}
281+
plan::PlanPtr
282+
sz::NTuple{N,Int} # size of array on which plan operates (Int tuple)
283+
osz::NTuple{N,Int} # size of output array (Int tuple)
284+
istride::NTuple{N,Int} # strides of input
285+
ostride::NTuple{N,Int} # strides of output
286+
ialign::Int32 # alignment mod 16 of input
287+
oalign::Int32 # alignment mod 16 of input
288+
flags::UInt32 # planner flags
289+
region::G # region (iterable) of dims that are transformed
290+
kinds::K
291+
pinv::ScaledPlan
292+
function r2rFFTWPlan{T,K,inplace,N,G}(plan::PlanPtr, flags::Integer, R::G,
293+
X::StridedArray{T,N},
294+
Y::StridedArray, kinds::K) where {T<:fftwNumber,K,inplace,N,G}
295+
p = new(plan, size(X), size(Y), strides(X), strides(Y),
296+
alignment_of(X), alignment_of(Y), flags, R, kinds)
297+
finalizer(maybe_destroy_plan, p)
298+
p
299+
end
300+
end
301+
302+
function r2rFFTWPlan{T,K,inplace,N}(plan::PlanPtr, flags::Integer, R::G,
303+
X::StridedArray{T,N},
304+
Y::StridedArray, kinds::K) where {T<:fftwNumber,K,inplace,N,G}
305+
r2rFFTWPlan{T,K,inplace,N,G}(plan, flags, R, X, Y, kinds)
306+
end
307+
280308
size(p::FFTWPlan) = p.sz
281309

282310
unsafe_convert(::Type{PlanPtr}, p::FFTWPlan) = p.plan
@@ -427,16 +455,17 @@ function show(io::IO, p::rFFTWPlan{T,K,inplace}) where {T,K,inplace}
427455
end
428456

429457
function show(io::IO, p::r2rFFTWPlan{T,K,inplace}) where {T,K,inplace}
458+
kinds = p.kinds
430459
print(io, inplace ? "FFTW in-place r2r " : "FFTW r2r ")
431-
if isempty(K)
460+
if isempty(kinds)
432461
print(io, "0-dimensional")
433-
elseif K == ntuple(i -> K[1], length(K))
434-
print(io, kind2string(K[1]))
435-
if length(K) > 1
436-
print(io, "^", length(K))
462+
elseif kinds == ntuple(i -> kinds[1], length(kinds))
463+
print(io, kind2string(kinds[1]))
464+
if length(kinds) > 1
465+
print(io, "^", length(kinds))
437466
end
438467
else
439-
print(io, join(map(kind2string, K), "×"))
468+
print(io, join(map(kind2string, kinds), "×"))
440469
end
441470
print(io, " plan for ")
442471
showfftdims(io, p.sz, p.istride, T)
@@ -553,7 +582,6 @@ function dims_howmany(X::StridedArray, Y::StridedArray,
553582
return (dims, howmany)
554583
end
555584

556-
# check & convert kinds into int32 array with same length as region
557585
function fix_kinds(region, kinds)
558586
if length(kinds) != length(region)
559587
if length(kinds) > length(region)
@@ -575,9 +603,14 @@ function fix_kinds(region, kinds)
575603
end
576604
return k
577605
end
606+
fix_kinds(region::Tuple, kinds::Integer) = ntuple(_->Int32(kinds), length(region))
607+
fix_kinds(region::Tuple, kinds::Tuple{Integer}) = fix_kinds(region, kinds[1])
578608

579-
# low-level FFTWPlan creation (for internal use in FFTW module)
609+
# Potentially avoid an extra `collect`
610+
_collect(T, x) = collect(T, x)
611+
_collect(::Type{T}, x::AbstractVector) where {T} = convert(Vector{T}, x)
580612

613+
# low-level FFTWPlan creation (for internal use in FFTW module)
581614
for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
582615
(:Float32,:(Complex{Float32}),"fftwf",:libfftw3f))
583616
@eval @exclusive function cFFTWPlan{$Tc,K,inplace,N}(X::StridedArray{$Tc,N},
@@ -644,6 +677,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
644677
Y::StridedArray{$Tr,N},
645678
region, kinds, flags::Integer,
646679
timelimit::Real) where {inplace,N}
680+
647681
R = isa(region, Tuple) ? region : copy(region)
648682
knd = fix_kinds(region, kinds)
649683
unsafe_set_timelimit($Tr, timelimit)
@@ -653,19 +687,20 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
653687
(Int32, Ptr{Int}, Int32, Ptr{Int},
654688
Ptr{$Tr}, Ptr{$Tr}, Ptr{Int32}, UInt32),
655689
size(dims,2), dims, size(howmany,2), howmany,
656-
X, Y, knd, flags)
690+
X, Y, _collect(Int32, knd), flags)
657691
unsafe_set_timelimit($Tr, NO_TIMELIMIT)
658692
if plan == C_NULL
659693
error("FFTW could not create plan") # shouldn't normally happen
660694
end
661-
r2rFFTWPlan{$Tr,(map(Int,knd)...,),inplace,N}(plan, flags, R, X, Y)
695+
r2rFFTWPlan{$Tr,typeof(knd),inplace,N}(plan, flags, R, X, Y, knd)
662696
end
663697

664698
# support r2r transforms of complex = transforms of real & imag parts
665699
@eval @exclusive function r2rFFTWPlan{$Tc,Any,inplace,N}(X::StridedArray{$Tc,N},
666700
Y::StridedArray{$Tc,N},
667701
region, kinds, flags::Integer,
668702
timelimit::Real) where {inplace,N}
703+
669704
R = isa(region, Tuple) ? region : copy(region)
670705
knd = fix_kinds(region, kinds)
671706
unsafe_set_timelimit($Tr, timelimit)
@@ -678,14 +713,13 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
678713
(Int32, Ptr{Int}, Int32, Ptr{Int},
679714
Ptr{$Tc}, Ptr{$Tc}, Ptr{Int32}, UInt32),
680715
size(dims,2), dims, size(howmany,2), howmany,
681-
X, Y, knd, flags)
716+
X, Y, _collect(Int32, knd), flags)
682717
unsafe_set_timelimit($Tr, NO_TIMELIMIT)
683718
if plan == C_NULL
684719
error("FFTW could not create plan") # shouldn't normally happen
685720
end
686-
r2rFFTWPlan{$Tc,(map(Int,knd)...,),inplace,N}(plan, flags, R, X, Y)
721+
r2rFFTWPlan{$Tc,typeof(knd),inplace,N}(plan, flags, R, X, Y, knd)
687722
end
688-
689723
end
690724

691725
# Convert arrays of numeric types to FFTW-supported packed complex-float types
@@ -888,15 +922,16 @@ end
888922

889923
# FFTW r2r transforms (low-level interface)
890924

925+
_ntupleid(v) = ntuple(identity, v)
891926
for f in (:r2r, :r2r!)
892927
pf = Symbol("plan_", f)
893928
@eval begin
894929
$f(x::AbstractArray{<:fftwNumber}, kinds) = $pf(x, kinds) * x
895930
$f(x::AbstractArray{<:fftwNumber}, kinds, region) = $pf(x, kinds, region) * x
896-
$pf(x::AbstractArray, kinds; kws...) = $pf(x, kinds, 1:ndims(x); kws...)
897-
$f(x::AbstractArray{<:Real}, kinds, region=1:ndims(x)) = $f(fftwfloat(x), kinds, region)
931+
$pf(x::AbstractArray, kinds; kws...) = $pf(x, kinds, _ntupleid(Val(ndims(x))); kws...)
932+
$f(x::AbstractArray{<:Real}, kinds, region=_ntupleid(Val(ndims(x)))) = $f(fftwfloat(x), kinds, region)
898933
$pf(x::AbstractArray{<:Real}, kinds, region; kws...) = $pf(fftwfloat(x), kinds, region; kws...)
899-
$f(x::AbstractArray{<:Complex}, kinds, region=1:ndims(x)) = $f(fftwcomplex(x), kinds, region)
934+
$f(x::AbstractArray{<:Complex}, kinds, region=_ntupleid(Val(ndims(x)))) = $f(fftwcomplex(x), kinds, region)
900935
$pf(x::AbstractArray{<:Complex}, kinds, region; kws...) = $pf(fftwcomplex(x), kinds, region; kws...)
901936
end
902937
end
@@ -955,13 +990,14 @@ function plan_inv(p::r2rFFTWPlan{T,K,inplace,N};
955990
return plan
956991
end
957992
X = Array{T}(undef, p.sz)
958-
iK = fix_kinds(p.region, [inv_kind[k] for k in K])
993+
# broadcast getindex to preserve tuples
994+
iK = fix_kinds(p.region, getindex.((inv_kind,), p.kinds))
959995
Y = inplace ? X : fakesimilar(p.flags, X, T)
960996
ScaledPlan(r2rFFTWPlan{T,Any,inplace,N}(X, Y, p.region, iK,
961997
p.flags, NO_TIMELIMIT),
962998
normalization(real(T),
963999
map(logical_size, [p.sz...][[p.region...]], iK),
964-
1:length(iK)))
1000+
1:length(p.region)))
9651001
end
9661002

9671003
function mul!(y::StridedArray{T}, p::r2rFFTWPlan{T}, x::StridedArray{T}) where T

test/runtests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,25 @@ end
552552
end
553553
@test FFTW.get_num_threads() == 2 # Unchanged
554554
end
555+
556+
@testset "type-inference in r2r plans" begin
557+
# Compare with definition
558+
function testr2r(::Type{T}) where {T}
559+
n = 4
560+
v = T[1:n;]
561+
plan = @inferred (() -> FFTW.plan_r2r(v, FFTW.REDFT10))()
562+
w = plan * v
563+
@test w [2sum(j->v[j+1]*cos(pi*(j+1/2)*k/n), 0:n-1) for k in 0:n-1]
564+
invplan = @inferred FFTW.plan_inv(plan)
565+
@test invplan * w v
566+
end
567+
@testset for T in (Float32, Float64)
568+
testr2r(T)
569+
end
570+
# complex r2r is broken on mkl
571+
if FFTW.get_provider() == "fftw"
572+
@testset for T in (ComplexF32, ComplexF64)
573+
testr2r(T)
574+
end
575+
end
576+
end

0 commit comments

Comments
 (0)