Skip to content

Fix inference of FFT plan creation #739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 37 additions & 21 deletions src/fft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,47 +136,63 @@ for (f, xtype, inplace, forward) in (
(:plan_fft, :rocfft_transform_type_complex_forward, :false, :true),
(:plan_bfft, :rocfft_transform_type_complex_inverse, :false, :false),
)
@eval function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N}
_inplace = $(inplace)
_xtype = $(xtype)
R = length(region)
region = NTuple{R,Int}(region)
pp = get_plan(_xtype, size(X), T, _inplace, region)
return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T)
@eval begin
# Try to constant-propagate the `region` argument so that its length `R` can be inferred.
Base.@constprop :aggressive function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N}
R = length(region)
region = NTuple{R,Int}(region)
return $f(X, region)
end

function $f(X::ROCArray{T, N}, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R}
_inplace = $(inplace)
_xtype = $(xtype)
pp = get_plan(_xtype, size(X), T, _inplace, region)
return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T)
end
end
end

function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N}
inplace = false
xtype = rocfft_transform_type_real_forward
Base.@constprop :aggressive function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N}
R = length(region)
region = NTuple{R,Int}(region)
return plan_rfft(X, region)
end

function plan_rfft(X::ROCArray{T,N}, region::NTuple{R,Int}) where {T<:rocfftReals,N,R}
inplace = false
xtype = rocfft_transform_type_real_forward
pp = get_plan(xtype, size(X), T, inplace, region)
ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]],2) + 1

xdims = size(X)
ydims = Base.setindex(xdims, div(xdims[region[1]],2) + 1, region[1])

# The buffer is not needed for real-to-complex (`mul!`),
# but it’s required for complex-to-real (`ldiv!`).
buffer = ROCArray{complex(T)}(undef, ydims...)
buffer = ROCArray{complex(T)}(undef, ydims)
B = typeof(buffer)

return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, true, T)
return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, true, T)
end

function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N}
inplace = false
xtype = rocfft_transform_type_real_inverse
Base.@constprop :aggressive function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N}
R = length(region)
region = NTuple{R,Int}(region)
ydims = collect(size(X))
ydims[region[1]] = d
pp = get_plan(xtype, (ydims...,), T, inplace, region)
return plan_brfft(X, d, region)
end

function plan_brfft(X::ROCArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R}
inplace = false
xtype = rocfft_transform_type_real_inverse
xdims = size(X)
ydims = Base.setindex(xdims, d, region[1])
pp = get_plan(xtype, ydims, T, inplace, region)

# Buffer to not modify the input in a complex-to-real FFT.
buffer = ROCArray{T}(undef, size(X))
B = typeof(buffer)

return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, false, T)
return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, false, T)
end

# FIXME: plan_inv methods allocate needlessly (to provide type parameters and normalization function)
Expand Down
10 changes: 5 additions & 5 deletions test/rocarray/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)

dX = ROCArray(X)
p = plan_fft(dX)
p = @inferred plan_fft(dX)
dY = p * dX
@test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL)
@test X ≈ collect(dX)
Expand All @@ -37,7 +37,7 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)

dX = ROCArray(X)
p = plan_fft!(dX)
p = @inferred plan_fft!(dX)
p * dX
@test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL)

Expand All @@ -50,7 +50,7 @@ function batched(X::AbstractArray{T,N}, region) where {T <: Complex,N}
fftw_X = fft(X, region)

dX = ROCArray(X)
p = plan_fft!(dX, region)
p = @inferred plan_fft!(dX, region)
p * dX
@test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL)

Expand Down Expand Up @@ -173,7 +173,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
fftw_X = rfft(X)
dX = ROCArray(X)

p = plan_rfft(dX)
p = @inferred plan_rfft(dX)
dY = p * dX
Y = collect(dY)
@test isapprox(Y, fftw_X; rtol=MYRTOL, atol=MYATOL)
Expand All @@ -197,7 +197,7 @@ function batched(X::AbstractArray{T,N},region) where {T <: Real,N}
fftw_X = rfft(X,region)
dX = ROCArray(X)

p = plan_rfft(dX, region)
p = @inferred plan_rfft(dX, region)
dY = p * dX
@test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL)
@test X ≈ collect(dX)
Expand Down