diff --git a/lib/cufft/fft.jl b/lib/cufft/fft.jl index c235ac2b93..8b173f59d1 100644 --- a/lib/cufft/fft.jl +++ b/lib/cufft/fft.jl @@ -152,12 +152,23 @@ end # region is an iterable subset of dimensions # spec. an integer, range, tuple, or array +# try to constant-propagate the `region` argument when it is not a tuple. This helps with +# inference of calls like plan_fft(X), which is translated by AbstractFFTs.jl into +# plan_fft(X, 1:ndims(X)). +for f in (:plan_fft!, :plan_bfft!, :plan_fft, :plan_bfft) + @eval begin + Base.@constprop :aggressive function $f(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} + R = length(region) + region = NTuple{R,Int}(region) + $f(X, region) + end + end +end + # inplace complex -function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_FORWARD inplace = true - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region, size(X)) sizex = size(X)[1:md] @@ -166,11 +177,9 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing) end -function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_INVERSE inplace = true - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region, size(X)) sizex = size(X)[1:md] @@ -180,11 +189,9 @@ function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} end # out-of-place complex -function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_FORWARD inplace = false - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] @@ -193,11 +200,9 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing) end -function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_INVERSE inplace = false - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] @@ -207,19 +212,23 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} end # out-of-place real-to-complex -function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} - K = CUFFT_FORWARD - inplace = false +Base.@constprop :aggressive function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} R = length(region) region = NTuple{R,Int}(region) + plan_rfft(X, region) +end + +function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftReals,N,R} + K = CUFFT_FORWARD + inplace = false md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] handle = cufftGetPlan(complex(T), T, sizex, region) - ydims = collect(size(X)) - ydims[region[1]] = div(ydims[region[1]], 2) + 1 + xdims = size(X) + ydims = Base.setindex(xdims, div(xdims[region[1]], 2) + 1, region[1]) # The buffer is not needed for real-to-complex (`mul!`), # but it’s required for complex-to-real (`ldiv!`). @@ -230,21 +239,24 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} end # out-of-place complex-to-real -function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N} - K = CUFFT_INVERSE - inplace = false +Base.@constprop :aggressive function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N} R = length(region) region = NTuple{R,Int}(region) + plan_brfft(X, d, region) +end - ydims = collect(size(X)) - ydims[region[1]] = d +function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} + K = CUFFT_INVERSE + inplace = false - handle = cufftGetPlan(real(T), T, (ydims...,), region) + xdims = size(X) + ydims = Base.setindex(xdims, d, region[1]) + handle = cufftGetPlan(real(T), T, ydims, region) buffer = CuArray{T}(undef, size(X)) B = typeof(buffer) - CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer) + CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), ydims, region, buffer) end diff --git a/test/libraries/cufft.jl b/test/libraries/cufft.jl index af200e5a61..6dc3212959 100644 --- a/test/libraries/cufft.jl +++ b/test/libraries/cufft.jl @@ -115,7 +115,7 @@ atol(::Type{Complex{T}}) where {T} = atol(T) function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N} fftw_X = fft(X) d_X = CuArray(X) - p = plan_fft(d_X) + p = @inferred plan_fft(d_X) d_Y = p * d_X Y = collect(d_Y) @test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T)) @@ -130,12 +130,16 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N} Z = collect(d_Z) @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) + pinvb = @inferred plan_bfft(d_Y) + d_Z = pinvb * d_Y + Z = collect(d_Z) ./ length(d_Z) + @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) end function in_place(X::AbstractArray{T,N}) where {T <: Complex,N} fftw_X = fft(X) d_X = CuArray(X) - p = plan_fft!(d_X) + p = @inferred plan_fft!(d_X) p * d_X Y = collect(d_X) @test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T)) @@ -144,6 +148,12 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N} pinv * d_X Z = collect(d_X) @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) + p * d_X + + pinvb = @inferred plan_bfft!(d_X) + pinvb * d_X + Z = collect(d_X) ./ length(X) + @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) end function batched(X::AbstractArray{T,N},region) where {T <: Complex,N} @@ -261,7 +271,7 @@ end function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N} fftw_X = rfft(X) d_X = CuArray(X) - p = plan_rfft(d_X) + p = @inferred plan_rfft(d_X) d_Y = p * d_X Y = collect(d_Y) @test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T)) @@ -280,6 +290,11 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N} d_W = pinv3 * d_X W = collect(d_W) @test isapprox(W, Y, rtol = rtol(T), atol = atol(T)) + + pinvb = @inferred plan_brfft(d_Y,size(X,1)) + d_Z = pinvb * d_Y + Z = collect(d_Z) ./ length(X) + @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) end function batched(X::AbstractArray{T,N},region) where {T <: Real,N}