Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 2106b08

Browse files
authored
Merge pull request #436 from JuliaGPU/tb/cufft_workarea
Use the memory pool for CUFFT workarea allocation.
2 parents b8b1c4e + 80603e2 commit 2106b08

File tree

2 files changed

+66
-52
lines changed

2 files changed

+66
-52
lines changed

src/fft/CUFFT.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using CUDAapi
44

55
using ..CuArrays
66
using ..CuArrays: libcufft
7+
import ..CuArrays: unsafe_free!
78

89
using CUDAdrv
910
using CUDAdrv: CuStream_t

src/fft/fft.jl

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,54 @@ Base.:(*)(p::ScaledPlan, x::CuArray) = rmul!(p.p * x, p.scale)
1717

1818
abstract type CuFFTPlan{T<:cufftNumber, K, inplace} <: Plan{T} end
1919

20+
Base.unsafe_convert(::Type{cufftHandle}, p::CuFFTPlan) = p.handle
21+
22+
# for some reason, cufftHandle is an integer and not a pointer...
23+
Base.convert(::Type{cufftHandle}, p::CuFFTPlan) = Base.unsafe_convert(cufftHandle, p)
24+
25+
function unsafe_free!(plan::CuFFTPlan)
26+
cufftDestroy(plan)
27+
unsafe_free!(plan.workarea)
28+
end
29+
2030
mutable struct cCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace}
21-
plan::cufftHandle
31+
handle::cufftHandle
32+
workarea::CuVector{Int8}
2233
sz::NTuple{N,Int} # Julia size of input array
2334
osz::NTuple{N,Int} # Julia size of output array
2435
xtype::cufftType
2536
region::Any
2637
pinv::ScaledPlan # required by AbstractFFT API
2738

28-
function cCuFFTPlan{T,K,inplace,N}(plan::cufftHandle, X::CuArray{T,N},
29-
sizey::Tuple, region, xtype
39+
function cCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, workarea::CuVector{Int8},
40+
X::CuArray{T,N}, sizey::Tuple, region, xtype
3041
) where {T<:cufftNumber,K,inplace,N}
3142
# maybe enforce consistency of sizey
32-
p = new(plan, size(X), sizey, xtype, region)
33-
finalizer(destroy_plan, p)
43+
p = new(handle, workarea, size(X), sizey, xtype, region)
44+
finalizer(unsafe_free!, p)
3445
p
3546
end
3647
end
3748

38-
cCuFFTPlan(plan,X,region,xtype) = cCuFFTPlan(plan,X,size(X),region,xtype)
39-
4049
mutable struct rCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace}
41-
plan::cufftHandle
50+
handle::cufftHandle
51+
workarea::CuVector{Int8}
4252
sz::NTuple{N,Int} # Julia size of input array
4353
osz::NTuple{N,Int} # Julia size of output array
4454
xtype::cufftType
4555
region::Any
4656
pinv::ScaledPlan # required by AbstractFFT API
4757

48-
function rCuFFTPlan{T,K,inplace,N}(plan::cufftHandle, X::CuArray{T,N},
49-
sizey::Tuple, region, xtype
58+
function rCuFFTPlan{T,K,inplace,N}(handle::cufftHandle, workarea::CuVector{Int8},
59+
X::CuArray{T,N}, sizey::Tuple, region, xtype
5060
) where {T<:cufftNumber,K,inplace,N}
5161
# maybe enforce consistency of sizey
52-
p = new(plan, size(X), sizey, xtype, region)
53-
finalizer(destroy_plan, p)
62+
p = new(handle, workarea, size(X), sizey, xtype, region)
63+
finalizer(unsafe_free!, p)
5464
p
5565
end
5666
end
5767

58-
rCuFFTPlan(plan,X,region,xtype) = rCuFFTPlan(plan,X,size(X),region,xtype)
59-
6068
const xtypenames = Dict{cufftType,String}(CUFFT_R2C => "real-to-complex",
6169
CUFFT_C2R => "complex-to-real",
6270
CUFFT_C2C => "complex",
@@ -83,12 +91,6 @@ function Base.show(io::IO, p::CuFFTPlan{T,K,inplace}) where {T,K,inplace}
8391
showfftdims(io, p.sz, T)
8492
end
8593

86-
Base.unsafe_convert(::Type{cufftHandle}, p::CuFFTPlan) = p.plan
87-
88-
Base.convert(::Type{cufftHandle}, p::CuFFTPlan) = p.plan
89-
90-
destroy_plan(plan::CuFFTPlan) = cufftDestroy(plan)
91-
9294
set_stream(plan::CuFFTPlan, stream::CuStream) = cufftSetStream(plan, stream)
9395

9496
Base.size(p::CuFFTPlan) = p.sz
@@ -97,26 +99,33 @@ Base.size(p::CuFFTPlan) = p.sz
9799
## plan methods
98100

99101
# Note: we don't implement padded storage dimensions
100-
function _mkplan(xtype, xdims, region)
102+
function create_plan(xtype, xdims, region)
101103
nrank = length(region)
102104
sz = [xdims[i] for i in region]
103105
csz = copy(sz)
104106
csz[1] = div(sz[1],2) + 1
105107
batch = prod(xdims) ÷ prod(sz)
106108

107-
pp = Ref{cufftHandle}()
109+
# initialize the plan handle
110+
handle_ref = Ref{cufftHandle}()
111+
cufftCreate(handle_ref)
112+
handle = handle_ref[]
113+
cufftSetAutoAllocation(handle, 0)
114+
115+
# make the plan
116+
worksize_ref = Ref{Csize_t}()
108117
if (nrank == 1) && (batch == 1)
109-
cufftPlan1d(pp, sz[1], xtype, 1)
118+
cufftMakePlan1d(handle, sz[1], xtype, 1, worksize_ref)
110119
elseif (nrank == 2) && (batch == 1)
111-
cufftPlan2d(pp, sz[2], sz[1], xtype)
120+
cufftMakePlan2d(handle, sz[2], sz[1], xtype, worksize_ref)
112121
elseif (nrank == 3) && (batch == 1)
113-
cufftPlan3d(pp, sz[3], sz[2], sz[1], xtype)
122+
cufftMakePlan3d(handle, sz[3], sz[2], sz[1], xtype, worksize_ref)
114123
else
115124
rsz = (length(sz) > 1) ? rsz = reverse(sz) : sz
116125
if ((region...,) == ((1:nrank)...,))
117126
# handle simple case ... simply! (for robustness)
118-
cufftPlanMany(pp, nrank, Cint[rsz...], C_NULL, 1, 1, C_NULL, 1, 1,
119-
xtype, batch)
127+
cufftMakePlanMany(handle, nrank, Cint[rsz...], C_NULL, 1, 1, C_NULL, 1, 1,
128+
xtype, batch, worksize_ref)
120129
else
121130
if nrank==1 || all(diff(collect(region)) .== 1)
122131
# _stride: successive elements in innermost dimension
@@ -207,12 +216,17 @@ function _mkplan(xtype, xdims, region)
207216
inembed = cnembed
208217
end
209218
end
210-
cufftPlanMany(pp, nrank, Cint[rsz...],
211-
inembed, istride, idist, onembed, ostride, odist,
212-
xtype, batch)
219+
cufftMakePlanMany(handle, nrank, Cint[rsz...],
220+
inembed, istride, idist, onembed, ostride, odist,
221+
xtype, batch, worksize_ref)
213222
end
214223
end
215-
pp[]
224+
225+
# assign the workarea
226+
workarea = CuArray{Int8}(undef, worksize_ref[])
227+
cufftSetWorkArea(handle, workarea)
228+
229+
handle, workarea
216230
end
217231

218232
# promote to a complex floating-point type (out-of-place only),
@@ -238,19 +252,19 @@ function plan_fft!(X::CuArray{T,N}, region) where {T<:cufftComplexes,N}
238252
inplace = true
239253
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
240254

241-
pp = _mkplan(xtype, size(X), region)
255+
pp = create_plan(xtype, size(X), region)
242256

243-
cCuFFTPlan{T,K,inplace,N}(pp, X, size(X), region, xtype)
257+
cCuFFTPlan{T,K,inplace,N}(pp..., X, size(X), region, xtype)
244258
end
245259

246260
function plan_bfft!(X::CuArray{T,N}, region) where {T<:cufftComplexes,N}
247261
K = CUFFT_INVERSE
248262
inplace = true
249263
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
250264

251-
pp = _mkplan(xtype, size(X), region)
265+
pp = create_plan(xtype, size(X), region)
252266

253-
cCuFFTPlan{T,K,inplace,N}(pp, X, size(X), region, xtype)
267+
cCuFFTPlan{T,K,inplace,N}(pp..., X, size(X), region, xtype)
254268
end
255269

256270
# out-of-place complex
@@ -259,19 +273,19 @@ function plan_fft(X::CuArray{T,N}, region) where {T<:cufftComplexes,N}
259273
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
260274
inplace = false
261275

262-
pp = _mkplan(xtype, size(X), region)
276+
pp = create_plan(xtype, size(X), region)
263277

264-
cCuFFTPlan{T,K,inplace,N}(pp, X, size(X), region, xtype)
278+
cCuFFTPlan{T,K,inplace,N}(pp..., X, size(X), region, xtype)
265279
end
266280

267281
function plan_bfft(X::CuArray{T,N}, region) where {T<:cufftComplexes,N}
268282
K = CUFFT_INVERSE
269283
inplace = false
270284
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
271285

272-
pp = _mkplan(xtype, size(X), region)
286+
pp = create_plan(xtype, size(X), region)
273287

274-
cCuFFTPlan{T,K,inplace,N}(pp, X, size(X), region, xtype)
288+
cCuFFTPlan{T,K,inplace,N}(pp..., X, size(X), region, xtype)
275289
end
276290

277291
# out-of-place real-to-complex
@@ -280,12 +294,12 @@ function plan_rfft(X::CuArray{T,N}, region) where {T<:cufftReals,N}
280294
inplace = false
281295
xtype = (T == cufftReal) ? CUFFT_R2C : CUFFT_D2Z
282296

283-
pp = _mkplan(xtype, size(X), region)
297+
pp = create_plan(xtype, size(X), region)
284298

285299
ydims = collect(size(X))
286300
ydims[region[1]] = div(ydims[region[1]],2)+1
287301

288-
rCuFFTPlan{T,K,inplace,N}(pp, X, (ydims...,), region, xtype)
302+
rCuFFTPlan{T,K,inplace,N}(pp..., X, (ydims...,), region, xtype)
289303
end
290304

291305
function plan_brfft(X::CuArray{T,N}, d::Integer, region::Any) where {T<:cufftComplexes,N}
@@ -295,26 +309,26 @@ function plan_brfft(X::CuArray{T,N}, d::Integer, region::Any) where {T<:cufftCom
295309
ydims = collect(size(X))
296310
ydims[region[1]] = d
297311

298-
pp = _mkplan(xtype, (ydims...,), region)
312+
pp = create_plan(xtype, (ydims...,), region)
299313

300-
rCuFFTPlan{T,K,inplace,N}(pp, X, (ydims...,), region, xtype)
314+
rCuFFTPlan{T,K,inplace,N}(pp..., X, (ydims...,), region, xtype)
301315
end
302316

303317
# FIXME: plan_inv methods allocate needlessly (to provide type parameters)
304318
# Perhaps use FakeArray types to avoid this.
305319

306320
function plan_inv(p::cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}) where {T,N,inplace}
307321
X = CuArray{T}(undef, p.sz)
308-
pp = _mkplan(p.xtype, p.sz, p.region)
309-
ScaledPlan(cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}(pp, X, p.sz, p.region,
322+
pp = create_plan(p.xtype, p.sz, p.region)
323+
ScaledPlan(cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}(pp..., X, p.sz, p.region,
310324
p.xtype),
311325
normalization(X, p.region))
312326
end
313327

314328
function plan_inv(p::cCuFFTPlan{T,CUFFT_INVERSE,inplace,N}) where {T,N,inplace}
315329
X = CuArray{T}(undef, p.sz)
316-
pp = _mkplan(p.xtype, p.sz, p.region)
317-
ScaledPlan(cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}(pp, X, p.sz, p.region,
330+
pp = create_plan(p.xtype, p.sz, p.region)
331+
ScaledPlan(cCuFFTPlan{T,CUFFT_FORWARD,inplace,N}(pp..., X, p.sz, p.region,
318332
p.xtype),
319333
normalization(X, p.region))
320334
end
@@ -324,9 +338,8 @@ function plan_inv(p::rCuFFTPlan{T,CUFFT_INVERSE,inplace,N}
324338
X = CuArray{real(T)}(undef, p.osz)
325339
Y = CuArray{T}(undef, p.sz)
326340
xtype = p.xtype == CUFFT_C2R ? CUFFT_R2C : CUFFT_D2Z
327-
pp = _mkplan(xtype, p.osz, p.region)
328-
ScaledPlan(rCuFFTPlan{real(T),CUFFT_FORWARD,inplace,N}(pp, X, p.sz, p.region,
329-
xtype),
341+
pp = create_plan(xtype, p.osz, p.region)
342+
ScaledPlan(rCuFFTPlan{real(T),CUFFT_FORWARD,inplace,N}(pp..., X, p.sz, p.region, xtype),
330343
normalization(X, p.region))
331344
end
332345

@@ -335,8 +348,8 @@ function plan_inv(p::rCuFFTPlan{T,CUFFT_FORWARD,inplace,N}
335348
X = CuArray{complex(T)}(undef, p.osz)
336349
Y = CuArray{T}(undef, p.sz)
337350
xtype = p.xtype == CUFFT_R2C ? CUFFT_C2R : CUFFT_Z2D
338-
pp = _mkplan(xtype, p.sz, p.region)
339-
ScaledPlan(rCuFFTPlan{complex(T),CUFFT_INVERSE,inplace,N}(pp, X, p.sz,
351+
pp = create_plan(xtype, p.sz, p.region)
352+
ScaledPlan(rCuFFTPlan{complex(T),CUFFT_INVERSE,inplace,N}(pp..., X, p.sz,
340353
p.region, xtype),
341354
normalization(Y, p.region))
342355
end

0 commit comments

Comments
 (0)