Skip to content

Fix inference of FFT plan creation #2683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 36 additions & 24 deletions lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,23 @@ end
# region is an iterable subset of dimensions
# spec. an integer, range, tuple, or array

# try to constant-propagate the `region` argument when it is not a tuple. This helps with
# inference of calls like plan_fft(X), which is translated by AbstractFFTs.jl into
# plan_fft(X, 1:ndims(X)).
for f in (:plan_fft!, :plan_bfft!, :plan_fft, :plan_bfft)
@eval begin
Base.@constprop :aggressive function $f(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
R = length(region)
region = NTuple{R,Int}(region)
$f(X, region)
end
end
end

# inplace complex
function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_FORWARD
inplace = true
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
Expand All @@ -166,11 +177,9 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_INVERSE
inplace = true
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
Expand All @@ -180,11 +189,9 @@ function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
end

# out-of-place complex
function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_FORWARD
inplace = false
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
Expand All @@ -193,11 +200,9 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_INVERSE
inplace = false
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
Expand All @@ -207,19 +212,23 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
end

# out-of-place real-to-complex
function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
K = CUFFT_FORWARD
inplace = false
Base.@constprop :aggressive function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
R = length(region)
region = NTuple{R,Int}(region)
plan_rfft(X, region)
end

function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftReals,N,R}
K = CUFFT_FORWARD
inplace = false

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]

handle = cufftGetPlan(complex(T), T, sizex, region)

ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]], 2) + 1
xdims = size(X)
ydims = Base.setindex(xdims, div(xdims[region[1]], 2) + 1, region[1])

# The buffer is not needed for real-to-complex (`mul!`),
# but it’s required for complex-to-real (`ldiv!`).
Expand All @@ -230,21 +239,24 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
end

# out-of-place complex-to-real
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = false
Base.@constprop :aggressive function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
R = length(region)
region = NTuple{R,Int}(region)
plan_brfft(X, d, region)
end

ydims = collect(size(X))
ydims[region[1]] = d
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
K = CUFFT_INVERSE
inplace = false

handle = cufftGetPlan(real(T), T, (ydims...,), region)
xdims = size(X)
ydims = Base.setindex(xdims, d, region[1])
handle = cufftGetPlan(real(T), T, ydims, region)

buffer = CuArray{T}(undef, size(X))
B = typeof(buffer)

CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), ydims, region, buffer)
end


Expand Down
21 changes: 18 additions & 3 deletions test/libraries/cufft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ atol(::Type{Complex{T}}) where {T} = atol(T)
function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)
d_X = CuArray(X)
p = plan_fft(d_X)
p = @inferred plan_fft(d_X)
d_Y = p * d_X
Y = collect(d_Y)
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
Expand All @@ -130,12 +130,16 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
Z = collect(d_Z)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))

pinvb = @inferred plan_bfft(d_Y)
d_Z = pinvb * d_Y
Z = collect(d_Z) ./ length(d_Z)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
end

function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)
d_X = CuArray(X)
p = plan_fft!(d_X)
p = @inferred plan_fft!(d_X)
p * d_X
Y = collect(d_X)
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
Expand All @@ -144,6 +148,12 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
pinv * d_X
Z = collect(d_X)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
p * d_X

pinvb = @inferred plan_bfft!(d_X)
pinvb * d_X
Z = collect(d_X) ./ length(X)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
end

function batched(X::AbstractArray{T,N},region) where {T <: Complex,N}
Expand Down Expand Up @@ -261,7 +271,7 @@ end
function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
fftw_X = rfft(X)
d_X = CuArray(X)
p = plan_rfft(d_X)
p = @inferred plan_rfft(d_X)
d_Y = p * d_X
Y = collect(d_Y)
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
Expand All @@ -280,6 +290,11 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
d_W = pinv3 * d_X
W = collect(d_W)
@test isapprox(W, Y, rtol = rtol(T), atol = atol(T))

pinvb = @inferred plan_brfft(d_Y,size(X,1))
d_Z = pinvb * d_Y
Z = collect(d_Z) ./ length(X)
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
end

function batched(X::AbstractArray{T,N},region) where {T <: Real,N}
Expand Down