Skip to content

Commit 19a08ef

Browse files
amontoison and maleadt authored
[CUFFT] Preallocate a buffer for complex-to-real FFT (#2578)
* [CUFFT] Preallocate a buffer for complex-to-real FFT
* Update cufft.jl
* Fix new errors in fft.jl
* More fixes in fft.jl
* Allocate a buffer in both plan_rfft and plan_brfft
* Allocate a buffer in both plan_rfft and plan_brfft
* Update lib/cufft/fft.jl

Co-authored-by: Tim Besard <tim.besard@gmail.com>

---------

Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent ca8f6cf commit 19a08ef

File tree

1 file changed

+71
-41
lines changed

1 file changed

+71
-41
lines changed

lib/cufft/fft.jl

Lines changed: 71 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -25,32 +25,34 @@ Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale)
2525

2626
# N is the number of dimensions
2727

28-
mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N} <: Plan{S}
28+
mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B} <: Plan{S}
2929
# handle to Cuda low level plan. Note that this plan sometimes has lower dimensions
3030
# to handle more transform cases such as individual directions
3131
handle::cufftHandle
3232
ctx::CuContext
3333
stream::CuStream
3434
input_size::NTuple{N,Int} # Julia size of input array
3535
output_size::NTuple{N,Int} # Julia size of output array
36-
region::Any
36+
region::NTuple{R,Int}
37+
buffer::B # buffer for out-of-place complex-to-real FFT, or `nothing` if not needed
3738
pinv::ScaledPlan{T} # required by AbstractFFTs API, will be defined by AbstractFFTs if needed
3839

39-
function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle,
40-
input_size::NTuple{N,Int}, output_size::NTuple{N,Int}, region
41-
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N}
40+
function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle,
41+
input_size::NTuple{N,Int}, output_size::NTuple{N,Int},
42+
region::NTuple{R,Int}, buffer::B
43+
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
4244
abs(K) == 1 || throw(ArgumentError("FFT direction must be either -1 (forward) or +1 (inverse)"))
4345
inplace isa Bool || throw(ArgumentError("FFT inplace argument must be a Bool"))
44-
p = new{T,S,K,inplace,N}(handle, context(), stream(), input_size, output_size, region)
46+
p = new{T,S,K,inplace,N,R,B}(handle, context(), stream(), input_size, output_size, region, buffer)
4547
finalizer(unsafe_free!, p)
4648
p
4749
end
4850
end
4951

50-
function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{S,N},
51-
sizey::NTuple{N,Int}, region,
52-
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N}
53-
CuFFTPlan{T,S,K,inplace,N}(handle, size(X), sizey, region)
52+
function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle, X::DenseCuArray{S,N},
53+
sizey::NTuple{N,Int}, region::NTuple{R,Int}, buffer::B
54+
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
55+
CuFFTPlan{T,S,K,inplace,N,R,B}(handle, size(X), sizey, region, buffer)
5456
end
5557

5658
function CUDA.unsafe_free!(plan::CuFFTPlan)
@@ -60,6 +62,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan)
6062
end
6163
plan.handle = C_NULL
6264
end
65+
if !isnothing(plan.buffer)
66+
CUDA.unsafe_free!(plan.buffer)
67+
end
6368
end
6469

6570
function showfftdims(io, sz, T)
@@ -151,103 +156,116 @@ end
151156
function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
152157
K = CUFFT_FORWARD
153158
inplace = true
154-
region = Tuple(region)
159+
R = length(region)
160+
region = NTuple{R,Int}(region)
155161

156162
md = plan_max_dims(region, size(X))
157163
sizex = size(X)[1:md]
158164
handle = cufftGetPlan(T, T, sizex, region)
159165

160-
CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
166+
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
161167
end
162168

163-
164169
function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
165170
K = CUFFT_INVERSE
166171
inplace = true
167-
region = Tuple(region)
172+
R = length(region)
173+
region = NTuple{R,Int}(region)
168174

169175
md = plan_max_dims(region, size(X))
170176
sizex = size(X)[1:md]
171177
handle = cufftGetPlan(T, T, sizex, region)
172178

173-
CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
179+
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
174180
end
175181

176182
# out-of-place complex
177183
function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
178184
K = CUFFT_FORWARD
179185
inplace = false
180-
region = Tuple(region)
186+
R = length(region)
187+
region = NTuple{R,Int}(region)
181188

182189
md = plan_max_dims(region,size(X))
183190
sizex = size(X)[1:md]
184191
handle = cufftGetPlan(T, T, sizex, region)
185192

186-
CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
193+
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
187194
end
188195

189196
function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
190197
K = CUFFT_INVERSE
191198
inplace = false
192-
region = Tuple(region)
199+
R = length(region)
200+
region = NTuple{R,Int}(region)
193201

194202
md = plan_max_dims(region,size(X))
195203
sizex = size(X)[1:md]
196204
handle = cufftGetPlan(T, T, sizex, region)
197205

198-
CuFFTPlan{T,T,K,inplace,N}(handle, size(X), size(X), region)
206+
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, size(X), size(X), region, nothing)
199207
end
200208

201209
# out-of-place real-to-complex
202210
function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
203211
K = CUFFT_FORWARD
204212
inplace = false
205-
region = Tuple(region)
213+
R = length(region)
214+
region = NTuple{R,Int}(region)
206215

207216
md = plan_max_dims(region,size(X))
208-
# X = front_view(X, md)
209217
sizex = size(X)[1:md]
210218

211219
handle = cufftGetPlan(complex(T), T, sizex, region)
212220

213221
ydims = collect(size(X))
214-
ydims[region[1]] = div(ydims[region[1]],2)+1
222+
ydims[region[1]] = div(ydims[region[1]], 2) + 1
215223

216-
CuFFTPlan{complex(T),T,K,inplace,N}(handle, size(X), (ydims...,), region)
224+
# The buffer is not needed for real-to-complex (`mul!`),
225+
# but it’s required for complex-to-real (`ldiv!`).
226+
buffer = CuArray{complex(T)}(undef, ydims...)
227+
B = typeof(buffer)
228+
229+
CuFFTPlan{complex(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
217230
end
218231

219-
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::Any) where {T<:cufftComplexes,N}
232+
# out-of-place complex-to-real
233+
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
220234
K = CUFFT_INVERSE
221235
inplace = false
222-
region = Tuple(region)
236+
R = length(region)
237+
region = NTuple{R,Int}(region)
223238

224239
ydims = collect(size(X))
225240
ydims[region[1]] = d
226241

227242
handle = cufftGetPlan(real(T), T, (ydims...,), region)
228243

229-
CuFFTPlan{real(T),T,K,inplace,N}(handle, size(X), (ydims...,), region)
244+
buffer = CuArray{T}(undef, size(X))
245+
B = typeof(buffer)
246+
247+
CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
230248
end
231249

232250

233251
# FIXME: plan_inv methods allocate needlessly (to provide type parameters)
234252
# Perhaps use FakeArray types to avoid this.
235253

236-
function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N}
237-
) where {T<:cufftNumber,S<:cufftNumber,N,inplace}
254+
function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N,R,B}
255+
) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
238256
md_osz = plan_max_dims(p.region, p.output_size)
239257
sz_X = p.output_size[1:md_osz]
240258
handle = cufftGetPlan(S, T, sz_X, p.region)
241-
ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N}(handle, p.output_size, p.input_size, p.region),
259+
ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
242260
normalization(real(T), p.output_size, p.region))
243261
end
244262

245-
function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N}
246-
) where {T<:cufftNumber,S<:cufftNumber,N,inplace}
263+
function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N,R,B}
264+
) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
247265
md_isz = plan_max_dims(p.region, p.input_size)
248266
sz_Y = p.input_size[1:md_isz]
249267
handle = cufftGetPlan(S, T, sz_Y, p.region)
250-
ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N}(handle, p.output_size, p.input_size, p.region),
268+
ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
251269
normalization(real(S), p.input_size, p.region))
252270
end
253271

@@ -309,10 +327,14 @@ function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T,S,K,inplace}, x::
309327
) where {T,S,K,inplace}
310328
assert_applicable(p, x, y)
311329
if !inplace && T<:Real
312-
# Out-of-place complex-to-real FFT will always overwrite input buffer.
313-
x = copy(x)
330+
# Out-of-place complex-to-real FFT will always overwrite input x.
331+
# We copy the input x in an auxiliary buffer.
332+
z = p.buffer
333+
copyto!(z, x)
334+
else
335+
z = x
314336
end
315-
unsafe_execute_trailing!(p, x, y)
337+
unsafe_execute_trailing!(p, z, y)
316338
y
317339
end
318340

@@ -323,13 +345,21 @@ function Base.:(*)(p::CuFFTPlan{T,S,K,true}, x::DenseCuArray{S}) where {T,S,K}
323345
end
324346

325347
function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,K,S1,M}
326-
if S1 != S || T<:Real
327-
# Convert to the expected input type. Also,
328-
# Out-of-place complex-to-real FFT will always overwrite input buffer.
329-
x = copy1(S, x)
348+
if T<:Real
349+
# Out-of-place complex-to-real FFT will always overwrite input x.
350+
# We copy the input x in an auxiliary buffer.
351+
z = p.buffer
352+
copyto!(z, x)
353+
else
354+
if S1 != S
355+
# Convert to the expected input type.
356+
z = copy1(S, x)
357+
else
358+
z = x
359+
end
330360
end
331-
assert_applicable(p, x)
361+
assert_applicable(p, z)
332362
y = CuArray{T,M}(undef, p.output_size)
333-
unsafe_execute_trailing!(p, x, y)
363+
unsafe_execute_trailing!(p, z, y)
334364
y
335365
end

0 commit comments

Comments
 (0)