Skip to content

Commit 6fdca86

Browse files
authored
Fix inference of FFT plan creation (#2683)
1 parent 9455b65 commit 6fdca86

File tree

2 files changed

+54
-27
lines changed

2 files changed

+54
-27
lines changed

lib/cufft/fft.jl

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,23 @@ end
152152
# region is an iterable subset of dimensions
153153
# spec. an integer, range, tuple, or array
154154

155+
# try to constant-propagate the `region` argument when it is not a tuple. This helps with
156+
# inference of calls like plan_fft(X), which is translated by AbstractFFTs.jl into
157+
# plan_fft(X, 1:ndims(X)).
158+
for f in (:plan_fft!, :plan_bfft!, :plan_fft, :plan_bfft)
159+
@eval begin
160+
Base.@constprop :aggressive function $f(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
161+
R = length(region)
162+
region = NTuple{R,Int}(region)
163+
$f(X, region)
164+
end
165+
end
166+
end
167+
155168
# inplace complex
156-
function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
169+
function plan_fft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
157170
K = CUFFT_FORWARD
158171
inplace = true
159-
R = length(region)
160-
region = NTuple{R,Int}(region)
161172

162173
md = plan_max_dims(region, size(X))
163174
sizex = size(X)[1:md]
@@ -166,11 +177,9 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
166177
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
167178
end
168179

169-
function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
180+
function plan_bfft!(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
170181
K = CUFFT_INVERSE
171182
inplace = true
172-
R = length(region)
173-
region = NTuple{R,Int}(region)
174183

175184
md = plan_max_dims(region, size(X))
176185
sizex = size(X)[1:md]
@@ -180,11 +189,9 @@ function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
180189
end
181190

182191
# out-of-place complex
183-
function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
192+
function plan_fft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
184193
K = CUFFT_FORWARD
185194
inplace = false
186-
R = length(region)
187-
region = NTuple{R,Int}(region)
188195

189196
md = plan_max_dims(region,size(X))
190197
sizex = size(X)[1:md]
@@ -193,11 +200,9 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
193200
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
194201
end
195202

196-
function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
203+
function plan_bfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
197204
K = CUFFT_INVERSE
198205
inplace = false
199-
R = length(region)
200-
region = NTuple{R,Int}(region)
201206

202207
md = plan_max_dims(region,size(X))
203208
sizex = size(X)[1:md]
@@ -207,19 +212,23 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
207212
end
208213

209214
# out-of-place real-to-complex
210-
function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
211-
K = CUFFT_FORWARD
212-
inplace = false
215+
Base.@constprop :aggressive function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
213216
R = length(region)
214217
region = NTuple{R,Int}(region)
218+
plan_rfft(X, region)
219+
end
220+
221+
function plan_rfft(X::DenseCuArray{T,N}, region::NTuple{R,Int}) where {T<:cufftReals,N,R}
222+
K = CUFFT_FORWARD
223+
inplace = false
215224

216225
md = plan_max_dims(region,size(X))
217226
sizex = size(X)[1:md]
218227

219228
handle = cufftGetPlan(complex(T), T, sizex, region)
220229

221-
ydims = collect(size(X))
222-
ydims[region[1]] = div(ydims[region[1]], 2) + 1
230+
xdims = size(X)
231+
ydims = Base.setindex(xdims, div(xdims[region[1]], 2) + 1, region[1])
223232

224233
# The buffer is not needed for real-to-complex (`mul!`),
225234
# but it’s required for complex-to-real (`ldiv!`).
@@ -230,21 +239,24 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
230239
end
231240

232241
# out-of-place complex-to-real
233-
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
234-
K = CUFFT_INVERSE
235-
inplace = false
242+
Base.@constprop :aggressive function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
236243
R = length(region)
237244
region = NTuple{R,Int}(region)
245+
plan_brfft(X, d, region)
246+
end
238247

239-
ydims = collect(size(X))
240-
ydims[region[1]] = d
248+
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T<:cufftComplexes,N,R}
249+
K = CUFFT_INVERSE
250+
inplace = false
241251

242-
handle = cufftGetPlan(real(T), T, (ydims...,), region)
252+
xdims = size(X)
253+
ydims = Base.setindex(xdims, d, region[1])
254+
handle = cufftGetPlan(real(T), T, ydims, region)
243255

244256
buffer = CuArray{T}(undef, size(X))
245257
B = typeof(buffer)
246258

247-
CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
259+
CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), ydims, region, buffer)
248260
end
249261

250262

test/libraries/cufft.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ atol(::Type{Complex{T}}) where {T} = atol(T)
115115
function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
116116
fftw_X = fft(X)
117117
d_X = CuArray(X)
118-
p = plan_fft(d_X)
118+
p = @inferred plan_fft(d_X)
119119
d_Y = p * d_X
120120
Y = collect(d_Y)
121121
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
@@ -130,12 +130,16 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
130130
Z = collect(d_Z)
131131
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
132132

133+
pinvb = @inferred plan_bfft(d_Y)
134+
d_Z = pinvb * d_Y
135+
Z = collect(d_Z) ./ length(d_Z)
136+
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
133137
end
134138

135139
function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
136140
fftw_X = fft(X)
137141
d_X = CuArray(X)
138-
p = plan_fft!(d_X)
142+
p = @inferred plan_fft!(d_X)
139143
p * d_X
140144
Y = collect(d_X)
141145
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
@@ -144,6 +148,12 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
144148
pinv * d_X
145149
Z = collect(d_X)
146150
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
151+
p * d_X
152+
153+
pinvb = @inferred plan_bfft!(d_X)
154+
pinvb * d_X
155+
Z = collect(d_X) ./ length(X)
156+
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
147157
end
148158

149159
function batched(X::AbstractArray{T,N},region) where {T <: Complex,N}
@@ -261,7 +271,7 @@ end
261271
function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
262272
fftw_X = rfft(X)
263273
d_X = CuArray(X)
264-
p = plan_rfft(d_X)
274+
p = @inferred plan_rfft(d_X)
265275
d_Y = p * d_X
266276
Y = collect(d_Y)
267277
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
@@ -280,6 +290,11 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
280290
d_W = pinv3 * d_X
281291
W = collect(d_W)
282292
@test isapprox(W, Y, rtol = rtol(T), atol = atol(T))
293+
294+
pinvb = @inferred plan_brfft(d_Y,size(X,1))
295+
d_Z = pinvb * d_Y
296+
Z = collect(d_Z) ./ length(X)
297+
@test isapprox(Z, X, rtol = rtol(T), atol = atol(T))
283298
end
284299

285300
function batched(X::AbstractArray{T,N},region) where {T <: Real,N}

0 commit comments

Comments
 (0)