From 74e8a314a659b709d7284ba0dd2b7fb6d2b187a7 Mon Sep 17 00:00:00 2001 From: Juan Ignacio Polanco Date: Thu, 6 Mar 2025 10:36:38 +0100 Subject: [PATCH 1/2] Fix inference of FFT plan creation --- lib/cufft/fft.jl | 58 ++++++++++++++++++++++++----------------- test/libraries/cufft.jl | 21 ++++++++++++--- 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/lib/cufft/fft.jl b/lib/cufft/fft.jl index c235ac2b93..ff0b09e2c4 100644 --- a/lib/cufft/fft.jl +++ b/lib/cufft/fft.jl @@ -152,12 +152,21 @@ end # region is an iterable subset of dimensions # spec. an integer, range, tuple, or array +# convert `region` to a tuple within an inlined function to help constant propagation +for f in (:plan_fft!, :plan_bfft!, :plan_fft, :plan_bfft) + @eval begin + @inline function $f(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} + R = length(region) + region = NTuple{R,Int}(region) + $f(X, region) + end + end +end + # inplace complex -function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_FORWARD inplace = true - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region, size(X)) sizex = size(X)[1:md] @@ -166,11 +175,9 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing) end -function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_INVERSE inplace = true - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region, size(X)) sizex = size(X)[1:md] @@ -180,11 +187,9 @@ function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} end # out-of-place complex -function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_FORWARD inplace = false - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] @@ -193,11 +198,9 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing) end -function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} +function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} K = CUFFT_INVERSE inplace = false - R = length(region) - region = NTuple{R,Int}(region) md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] @@ -207,19 +210,23 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} end # out-of-place real-to-complex -function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} - K = CUFFT_FORWARD - inplace = false +@inline function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} R = length(region) region = NTuple{R,Int}(region) + plan_rfft(X, region) +end + +function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftReals,N,R} + K = CUFFT_FORWARD + inplace = false md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] handle = cufftGetPlan(complex(T), T, sizex, region) - ydims = collect(size(X)) - ydims[region[1]] = div(ydims[region[1]], 2) + 1 + xdims = size(X) + ydims = Base.setindex(xdims, div(xdims[region[1]], 2) + 1, region[1]) # The buffer is not needed for real-to-complex (`mul!`), # but it’s required for complex-to-real (`ldiv!`). @@ -230,21 +237,24 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} end # out-of-place complex-to-real -function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N} - K = CUFFT_INVERSE - inplace = false +@inline function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N} R = length(region) region = NTuple{R,Int}(region) + plan_brfft(X, d, region) +end - ydims = collect(size(X)) - ydims[region[1]] = d +function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R} + K = CUFFT_INVERSE + inplace = false - handle = cufftGetPlan(real(T), T, (ydims...,), region) + xdims = size(X) + ydims = Base.setindex(xdims, d, region[1]) + handle = cufftGetPlan(real(T), T, ydims, region) buffer = CuArray{T}(undef, size(X)) B = typeof(buffer) - CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer) + CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), ydims, region, buffer) end diff --git a/test/libraries/cufft.jl b/test/libraries/cufft.jl index af200e5a61..6dc3212959 100644 --- a/test/libraries/cufft.jl +++ b/test/libraries/cufft.jl @@ -115,7 +115,7 @@ atol(::Type{Complex{T}}) where {T} = atol(T) function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N} fftw_X = fft(X) d_X = CuArray(X) - p = plan_fft(d_X) + p = @inferred plan_fft(d_X) d_Y = p * d_X Y = collect(d_Y) @test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T)) @@ -130,12 +130,16 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N} Z = collect(d_Z) @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) + pinvb = @inferred plan_bfft(d_Y) + d_Z = pinvb * d_Y + Z = collect(d_Z) ./ length(d_Z) + @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) end function in_place(X::AbstractArray{T,N}) where {T <: Complex,N} fftw_X = fft(X) d_X = CuArray(X) - p = plan_fft!(d_X) + p = @inferred plan_fft!(d_X) p * d_X Y = collect(d_X) @test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T)) @@ -144,6 +148,12 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N} pinv * d_X Z = collect(d_X) @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) + p * d_X + + pinvb = @inferred plan_bfft!(d_X) + pinvb * d_X + Z = collect(d_X) ./ length(X) + @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) end function batched(X::AbstractArray{T,N},region) where {T <: Complex,N} @@ -261,7 +271,7 @@ end function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N} fftw_X = rfft(X) d_X = CuArray(X) - p = plan_rfft(d_X) + p = @inferred plan_rfft(d_X) d_Y = p * d_X Y = collect(d_Y) @test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T)) @@ -280,6 +290,11 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N} d_W = pinv3 * d_X W = collect(d_W) @test isapprox(W, Y, rtol = rtol(T), atol = atol(T)) + + pinvb = @inferred plan_brfft(d_Y,size(X,1)) + d_Z = pinvb * d_Y + Z = collect(d_Z) ./ length(X) + @test isapprox(Z, X, rtol = rtol(T), atol = atol(T)) end function batched(X::AbstractArray{T,N},region) where {T <: Real,N} From ec550270df1fa092705a935ed9e659c7bbee2b6e Mon Sep 17 00:00:00 2001 From: Juan Ignacio Polanco Date: Fri, 7 Mar 2025 09:16:37 +0100 Subject: [PATCH 2/2] Use @constprop :aggressive --- lib/cufft/fft.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/cufft/fft.jl b/lib/cufft/fft.jl index ff0b09e2c4..8b173f59d1 100644 --- a/lib/cufft/fft.jl +++ b/lib/cufft/fft.jl @@ -152,10 +152,12 @@ end # region is an iterable subset of dimensions # spec. an integer, range, tuple, or array -# convert `region` to a tuple within an inlined function to help constant propagation +# try to constant-propagate the `region` argument when it is not a tuple. This helps with +# inference of calls like plan_fft(X), which is translated by AbstractFFTs.jl into +# plan_fft(X, 1:ndims(X)). for f in (:plan_fft!, :plan_bfft!, :plan_fft, :plan_bfft) @eval begin - @inline function $f(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} + Base.@constprop :aggressive function $f(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} R = length(region) region = NTuple{R,Int}(region) $f(X, region) @@ -210,7 +212,7 @@ function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftC end # out-of-place real-to-complex -@inline function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} +Base.@constprop :aggressive function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} R = length(region) region = NTuple{R,Int}(region) plan_rfft(X, region) @@ -237,7 +239,7 @@ function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftR end # out-of-place complex-to-real -@inline function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N} +Base.@constprop :aggressive function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N} R = length(region) region = NTuple{R,Int}(region) plan_brfft(X, d, region)