Skip to content

Commit e4a00b1

Browse files
jishnub and stevengj authored
reduce allocations in dims_howmany (#269)
* reduce allocations in dims_howmany
* Update src/fft.jl (Co-authored-by: Steven G. Johnson <stevenj@mit.edu>)
* Don't collect size tuple
* filter for Int/Tuple regions
* use tuple instead of vector region at more places
* remove unused methods
* test region collections
* bump version to v1.7.0

---------

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
1 parent 82a99dc commit e4a00b1

File tree

3 files changed

+66
-26
lines changed

3 files changed

+66
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FFTW"
22
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
3-
version = "1.6.1"
3+
version = "1.7.0"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/fft.jl

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -566,20 +566,51 @@ unsafe_execute!(plan::r2rFFTWPlan{T},
566566
# re-use the table of trigonometric constants from the first plan.
567567

568568
# Compute dims and howmany for FFTW guru planner
569-
function dims_howmany(X::StridedArray, Y::StridedArray,
570-
sz::Vector{Int}, region)
571-
reg = Int[region...]::Vector{Int}
572-
if length(unique(reg)) < length(reg)
569+
# Report whether any entry of `region` occurs more than once.
# A single number or a unit range can never contain a repeated entry.
_anyrepeated(::Union{Number, AbstractUnitRange}) = false
function _anyrepeated(region)
    # Count occurrences of each entry in turn and stop at the first duplicate;
    # this avoids building an intermediate collection such as `unique(region)`.
    for d in region
        noccurrences = count(r -> r == d, region)
        noccurrences > 1 && return true
    end
    return false
end
575+
576+
# Utility methods to reduce allocations in dims_howmany

# Store `v` at position `n` of `oreg` and return the updated collection.
# Mutable collections are updated in place; tuples are immutable, so a new
# tuple with the entry replaced is returned instead.
@inline function _setindex(oreg, v, n)
    oreg[n] = v
    return oreg
end
@inline _setindex(oreg::Tuple, v, n) = Base.setindex(oreg, v, n)

# Create the container of length `len` that will receive the dimensions not in
# `region`: a tuple of zeros when `region` is an `Int` or a `Tuple`, and an
# uninitialised `Vector{Int}` for any other kind of region.
@inline _filtercoll(region::Union{Int, Tuple}, len) = ntuple(_ -> 0, len)
@inline _filtercoll(region, len) = Vector{Int}(undef, len)
581+
# Optimized filter(∉(region), 1:ndims(X))
#
# Return the dimensions in `1:ndimsX` that are not members of `region`, in
# increasing order. The container type is chosen by `_filtercoll`, and entries
# are written through `_setindex` so that both tuples and vectors are supported.
function _filter_notin_region(region, ::Val{ndimsX}) where {ndimsX}
    kept = _filtercoll(region, ndimsX - length(region))
    slot = 0
    for dim in 1:ndimsX
        if !(dim in region)
            slot += 1
            # Rebind the result: for tuples `_setindex` returns a new tuple.
            kept = _setindex(kept, dim, slot)
        end
    end
    return kept
end
592+
function dims_howmany(X::StridedArray, Y::StridedArray, sz, region)
593+
if _anyrepeated(region)
573594
throw(ArgumentError("each dimension can be transformed at most once"))
574595
end
575-
ist = [strides(X)...]
576-
ost = [strides(Y)...]
577-
dims = Matrix(transpose([sz[reg] ist[reg] ost[reg]]))
578-
oreg = [1:ndims(X);]
579-
oreg[reg] .= 0
580-
oreg = filter(d -> d > 0, oreg)
581-
howmany = Matrix(transpose([sz[oreg] ist[oreg] ost[oreg]]))
582-
return (dims, howmany)
596+
ist = strides(X)
597+
ost = strides(Y)
598+
dims = Matrix{Int}(undef, 3, length(region))
599+
for (ind, i) in enumerate(region)
600+
dims[1, ind] = sz[i]
601+
dims[2, ind] = ist[i]
602+
dims[3, ind] = ost[i]
603+
end
604+
605+
oreg = _filter_notin_region(region, Val(ndims(X)))
606+
howmany = Matrix{Int}(undef, 3, length(oreg))
607+
for (ind, i) in enumerate(oreg)
608+
howmany[1, ind] = sz[i]
609+
howmany[2, ind] = ist[i]
610+
howmany[3, ind] = ost[i]
611+
end
612+
613+
return dims, howmany
583614
end
584615

585616
function fix_kinds(region, kinds)
@@ -604,6 +635,10 @@ function fix_kinds(region, kinds)
604635
return k
605636
end
606637

638+
# Rotate `region` one position to the left, so that its first entry becomes the
# last one. The real-data plan constructors in this file apply it to `region`
# before calling `dims_howmany` (their call sites note that FFTW halves the
# last dimension).
#
# - generic collections are collected into a `Vector{Int}` and shifted;
# - tuples stay tuples, so no array is allocated;
# - an empty tuple is returned unchanged instead of throwing a `BoundsError`,
#   matching the generic method, which returns an empty result for empty input;
# - a single integer has nothing to rotate and is returned as is.
_circshiftmin1(v) = circshift(collect(Int, v), -1)
_circshiftmin1(t::Tuple{}) = t
_circshiftmin1(t::Tuple) = (Base.tail(t)..., first(t))
_circshiftmin1(x::Integer) = x
641+
607642
# low-level FFTWPlan creation (for internal use in FFTW module)
608643
for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
609644
(:Float32,:(Complex{Float32}),"fftwf",:libfftw3f))
@@ -613,7 +648,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
613648
direction = K
614649
unsafe_set_timelimit($Tr, timelimit)
615650
R = isa(region, Tuple) ? region : copy(region)
616-
dims, howmany = dims_howmany(X, Y, [size(X)...], R)
651+
dims, howmany = dims_howmany(X, Y, size(X), R)
617652
plan = ccall(($(string(fftw,"_plan_guru64_dft")),$lib[]),
618653
PlanPtr,
619654
(Int32, Ptr{Int}, Int32, Ptr{Int},
@@ -631,9 +666,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
631666
Y::StridedArray{$Tc,N},
632667
region, flags::Integer, timelimit::Real) where {inplace,N}
633668
R = isa(region, Tuple) ? region : copy(region)
634-
region = circshift(Int[region...],-1) # FFTW halves last dim
669+
regionshft = _circshiftmin1(region) # FFTW halves last dim
635670
unsafe_set_timelimit($Tr, timelimit)
636-
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
671+
dims, howmany = dims_howmany(X, Y, size(X), regionshft)
637672
plan = ccall(($(string(fftw,"_plan_guru64_dft_r2c")),$lib[]),
638673
PlanPtr,
639674
(Int32, Ptr{Int}, Int32, Ptr{Int},
@@ -651,9 +686,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
651686
Y::StridedArray{$Tr,N},
652687
region, flags::Integer, timelimit::Real) where {inplace,N}
653688
R = isa(region, Tuple) ? region : copy(region)
654-
region = circshift(Int[region...],-1) # FFTW halves last dim
689+
regionshft = _circshiftmin1(region) # FFTW halves last dim
655690
unsafe_set_timelimit($Tr, timelimit)
656-
dims, howmany = dims_howmany(X, Y, [size(Y)...], region)
691+
dims, howmany = dims_howmany(X, Y, size(Y), regionshft)
657692
plan = ccall(($(string(fftw,"_plan_guru64_dft_c2r")),$lib[]),
658693
PlanPtr,
659694
(Int32, Ptr{Int}, Int32, Ptr{Int},
@@ -675,7 +710,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
675710
R = isa(region, Tuple) ? region : copy(region)
676711
knd = fix_kinds(region, kinds)
677712
unsafe_set_timelimit($Tr, timelimit)
678-
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
713+
dims, howmany = dims_howmany(X, Y, size(X), region)
679714
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
680715
PlanPtr,
681716
(Int32, Ptr{Int}, Int32, Ptr{Int},
@@ -698,9 +733,11 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3),
698733
R = isa(region, Tuple) ? region : copy(region)
699734
knd = fix_kinds(region, kinds)
700735
unsafe_set_timelimit($Tr, timelimit)
701-
dims, howmany = dims_howmany(X, Y, [size(X)...], region)
702-
dims[2:3, 1:size(dims,2)] *= 2
703-
howmany[2:3, 1:size(howmany,2)] *= 2
736+
dims, howmany = dims_howmany(X, Y, size(X), region)
737+
@views begin
738+
dims[2:3, :] .*= 2
739+
howmany[2:3, :] .*= 2
740+
end
704741
howmany = [howmany [2,1,1]] # append loop over real/imag parts
705742
plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]),
706743
PlanPtr,
@@ -759,9 +796,9 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
759796
cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit)
760797
end
761798
$plan_f(X::StridedArray{<:fftwComplex}; kws...) =
762-
$plan_f(X, 1:ndims(X); kws...)
799+
$plan_f(X, ntuple(identity, ndims(X)); kws...)
763800
$plan_f!(X::StridedArray{<:fftwComplex}; kws...) =
764-
$plan_f!(X, 1:ndims(X); kws...)
801+
$plan_f!(X, ntuple(identity, ndims(X)); kws...)
765802

766803
function plan_inv(p::cFFTWPlan{T,$direction,inplace,N};
767804
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace}
@@ -845,8 +882,8 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
845882
end
846883
end
847884

848-
plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
849-
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...)
885+
plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,ntuple(identity, ndims(X));kws...)
886+
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,ntuple(identity, ndims(X));kws...)
850887

851888
function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N},
852889
num_threads::Union{Nothing, Integer} = nothing) where N

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ true_fftd3_m3d[:,:,2] .= -15
181181
end
182182

183183
@testset "rfft/rfftn" begin
184+
# Test regions as int/collection
185+
@test rfft(m4,1) == rfft(m4,1:1) == rfft(m4,(1,)) == rfft(m4, [1])
186+
184187
rfft_m4 = rfft(m4,1)
185188
rfftd2_m4 = rfft(m4,2)
186189
rfftn_m4 = rfft(m4)

0 commit comments

Comments
 (0)