@@ -17,46 +17,54 @@ Base.:(*)(p::ScaledPlan, x::CuArray) = rmul!(p.p * x, p.scale)
17
17
18
18
abstract type CuFFTPlan{T<: cufftNumber , K, inplace} <: Plan{T} end
19
19
20
+ Base. unsafe_convert (:: Type{cufftHandle} , p:: CuFFTPlan ) = p. handle
21
+
22
+ # for some reason, cufftHandle is an integer and not a pointer...
23
+ Base. convert (:: Type{cufftHandle} , p:: CuFFTPlan ) = Base. unsafe_convert (cufftHandle, p)
24
+
25
+ function unsafe_free! (plan:: CuFFTPlan )
26
+ cufftDestroy (plan)
27
+ unsafe_free! (plan. workarea)
28
+ end
29
+
20
30
mutable struct cCuFFTPlan{T<: cufftNumber ,K,inplace,N} <: CuFFTPlan{T,K,inplace}
21
- plan:: cufftHandle
31
+ handle:: cufftHandle
32
+ workarea:: CuVector{Int8}
22
33
sz:: NTuple{N,Int} # Julia size of input array
23
34
osz:: NTuple{N,Int} # Julia size of output array
24
35
xtype:: cufftType
25
36
region:: Any
26
37
pinv:: ScaledPlan # required by AbstractFFT API
27
38
28
- function cCuFFTPlan {T,K,inplace,N} (plan :: cufftHandle , X :: CuArray{T,N } ,
29
- sizey:: Tuple , region, xtype
39
+ function cCuFFTPlan {T,K,inplace,N} (handle :: cufftHandle , workarea :: CuVector{Int8 } ,
40
+ X :: CuArray{T,N} , sizey:: Tuple , region, xtype
30
41
) where {T<: cufftNumber ,K,inplace,N}
31
42
# maybe enforce consistency of sizey
32
- p = new (plan , size (X), sizey, xtype, region)
33
- finalizer (destroy_plan , p)
43
+ p = new (handle, workarea , size (X), sizey, xtype, region)
44
+ finalizer (unsafe_free! , p)
34
45
p
35
46
end
36
47
end
37
48
38
- cCuFFTPlan (plan,X,region,xtype) = cCuFFTPlan (plan,X,size (X),region,xtype)
39
-
40
49
mutable struct rCuFFTPlan{T<: cufftNumber ,K,inplace,N} <: CuFFTPlan{T,K,inplace}
41
- plan:: cufftHandle
50
+ handle:: cufftHandle
51
+ workarea:: CuVector{Int8}
42
52
sz:: NTuple{N,Int} # Julia size of input array
43
53
osz:: NTuple{N,Int} # Julia size of output array
44
54
xtype:: cufftType
45
55
region:: Any
46
56
pinv:: ScaledPlan # required by AbstractFFT API
47
57
48
- function rCuFFTPlan {T,K,inplace,N} (plan :: cufftHandle , X :: CuArray{T,N } ,
49
- sizey:: Tuple , region, xtype
58
+ function rCuFFTPlan {T,K,inplace,N} (handle :: cufftHandle , workarea :: CuVector{Int8 } ,
59
+ X :: CuArray{T,N} , sizey:: Tuple , region, xtype
50
60
) where {T<: cufftNumber ,K,inplace,N}
51
61
# maybe enforce consistency of sizey
52
- p = new (plan , size (X), sizey, xtype, region)
53
- finalizer (destroy_plan , p)
62
+ p = new (handle, workarea , size (X), sizey, xtype, region)
63
+ finalizer (unsafe_free! , p)
54
64
p
55
65
end
56
66
end
57
67
58
- rCuFFTPlan (plan,X,region,xtype) = rCuFFTPlan (plan,X,size (X),region,xtype)
59
-
60
68
const xtypenames = Dict {cufftType,String} (CUFFT_R2C => " real-to-complex" ,
61
69
CUFFT_C2R => " complex-to-real" ,
62
70
CUFFT_C2C => " complex" ,
@@ -83,12 +91,6 @@ function Base.show(io::IO, p::CuFFTPlan{T,K,inplace}) where {T,K,inplace}
83
91
showfftdims (io, p. sz, T)
84
92
end
85
93
86
- Base. unsafe_convert (:: Type{cufftHandle} , p:: CuFFTPlan ) = p. plan
87
-
88
- Base. convert (:: Type{cufftHandle} , p:: CuFFTPlan ) = p. plan
89
-
90
- destroy_plan (plan:: CuFFTPlan ) = cufftDestroy (plan)
91
-
92
94
set_stream (plan:: CuFFTPlan , stream:: CuStream ) = cufftSetStream (plan, stream)
93
95
94
96
Base. size (p:: CuFFTPlan ) = p. sz
@@ -97,26 +99,33 @@ Base.size(p::CuFFTPlan) = p.sz
97
99
# # plan methods
98
100
99
101
# Note: we don't implement padded storage dimensions
100
- function _mkplan (xtype, xdims, region)
102
+ function create_plan (xtype, xdims, region)
101
103
nrank = length (region)
102
104
sz = [xdims[i] for i in region]
103
105
csz = copy (sz)
104
106
csz[1 ] = div (sz[1 ],2 ) + 1
105
107
batch = prod (xdims) ÷ prod (sz)
106
108
107
- pp = Ref {cufftHandle} ()
109
+ # initialize the plan handle
110
+ handle_ref = Ref {cufftHandle} ()
111
+ cufftCreate (handle_ref)
112
+ handle = handle_ref[]
113
+ cufftSetAutoAllocation (handle, 0 )
114
+
115
+ # make the plan
116
+ worksize_ref = Ref {Csize_t} ()
108
117
if (nrank == 1 ) && (batch == 1 )
109
- cufftPlan1d (pp , sz[1 ], xtype, 1 )
118
+ cufftMakePlan1d (handle , sz[1 ], xtype, 1 , worksize_ref )
110
119
elseif (nrank == 2 ) && (batch == 1 )
111
- cufftPlan2d (pp , sz[2 ], sz[1 ], xtype)
120
+ cufftMakePlan2d (handle , sz[2 ], sz[1 ], xtype, worksize_ref )
112
121
elseif (nrank == 3 ) && (batch == 1 )
113
- cufftPlan3d (pp , sz[3 ], sz[2 ], sz[1 ], xtype)
122
+ cufftMakePlan3d (handle , sz[3 ], sz[2 ], sz[1 ], xtype, worksize_ref )
114
123
else
115
124
rsz = (length (sz) > 1 ) ? rsz = reverse (sz) : sz
116
125
if ((region... ,) == ((1 : nrank). .. ,))
117
126
# handle simple case ... simply! (for robustness)
118
- cufftPlanMany (pp , nrank, Cint[rsz... ], C_NULL , 1 , 1 , C_NULL , 1 , 1 ,
119
- xtype, batch)
127
+ cufftMakePlanMany (handle , nrank, Cint[rsz... ], C_NULL , 1 , 1 , C_NULL , 1 , 1 ,
128
+ xtype, batch, worksize_ref )
120
129
else
121
130
if nrank== 1 || all (diff (collect (region)) .== 1 )
122
131
# _stride: successive elements in innermost dimension
@@ -207,12 +216,17 @@ function _mkplan(xtype, xdims, region)
207
216
inembed = cnembed
208
217
end
209
218
end
210
- cufftPlanMany (pp , nrank, Cint[rsz... ],
211
- inembed, istride, idist, onembed, ostride, odist,
212
- xtype, batch)
219
+ cufftMakePlanMany (handle , nrank, Cint[rsz... ],
220
+ inembed, istride, idist, onembed, ostride, odist,
221
+ xtype, batch, worksize_ref )
213
222
end
214
223
end
215
- pp[]
224
+
225
+ # assign the workarea
226
+ workarea = CuArray {Int8} (undef, worksize_ref[])
227
+ cufftSetWorkArea (handle, workarea)
228
+
229
+ handle, workarea
216
230
end
217
231
218
232
# promote to a complex floating-point type (out-of-place only),
@@ -238,19 +252,19 @@ function plan_fft!(X::CuArray{T,N}, region) where {T<:cufftComplexes,N}
238
252
inplace = true
239
253
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
240
254
241
- pp = _mkplan (xtype, size (X), region)
255
+ pp = create_plan (xtype, size (X), region)
242
256
243
- cCuFFTPlan {T,K,inplace,N} (pp, X, size (X), region, xtype)
257
+ cCuFFTPlan {T,K,inplace,N} (pp... , X, size (X), region, xtype)
244
258
end
245
259
246
260
function plan_bfft! (X:: CuArray{T,N} , region) where {T<: cufftComplexes ,N}
247
261
K = CUFFT_INVERSE
248
262
inplace = true
249
263
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
250
264
251
- pp = _mkplan (xtype, size (X), region)
265
+ pp = create_plan (xtype, size (X), region)
252
266
253
- cCuFFTPlan {T,K,inplace,N} (pp, X, size (X), region, xtype)
267
+ cCuFFTPlan {T,K,inplace,N} (pp... , X, size (X), region, xtype)
254
268
end
255
269
256
270
# out-of-place complex
@@ -259,19 +273,19 @@ function plan_fft(X::CuArray{T,N}, region) where {T<:cufftComplexes,N}
259
273
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
260
274
inplace = false
261
275
262
- pp = _mkplan (xtype, size (X), region)
276
+ pp = create_plan (xtype, size (X), region)
263
277
264
- cCuFFTPlan {T,K,inplace,N} (pp, X, size (X), region, xtype)
278
+ cCuFFTPlan {T,K,inplace,N} (pp... , X, size (X), region, xtype)
265
279
end
266
280
267
281
function plan_bfft (X:: CuArray{T,N} , region) where {T<: cufftComplexes ,N}
268
282
K = CUFFT_INVERSE
269
283
inplace = false
270
284
xtype = (T == cufftComplex) ? CUFFT_C2C : CUFFT_Z2Z
271
285
272
- pp = _mkplan (xtype, size (X), region)
286
+ pp = create_plan (xtype, size (X), region)
273
287
274
- cCuFFTPlan {T,K,inplace,N} (pp, X, size (X), region, xtype)
288
+ cCuFFTPlan {T,K,inplace,N} (pp... , X, size (X), region, xtype)
275
289
end
276
290
277
291
# out-of-place real-to-complex
@@ -280,12 +294,12 @@ function plan_rfft(X::CuArray{T,N}, region) where {T<:cufftReals,N}
280
294
inplace = false
281
295
xtype = (T == cufftReal) ? CUFFT_R2C : CUFFT_D2Z
282
296
283
- pp = _mkplan (xtype, size (X), region)
297
+ pp = create_plan (xtype, size (X), region)
284
298
285
299
ydims = collect (size (X))
286
300
ydims[region[1 ]] = div (ydims[region[1 ]],2 )+ 1
287
301
288
- rCuFFTPlan {T,K,inplace,N} (pp, X, (ydims... ,), region, xtype)
302
+ rCuFFTPlan {T,K,inplace,N} (pp... , X, (ydims... ,), region, xtype)
289
303
end
290
304
291
305
function plan_brfft (X:: CuArray{T,N} , d:: Integer , region:: Any ) where {T<: cufftComplexes ,N}
@@ -295,26 +309,26 @@ function plan_brfft(X::CuArray{T,N}, d::Integer, region::Any) where {T<:cufftCom
295
309
ydims = collect (size (X))
296
310
ydims[region[1 ]] = d
297
311
298
- pp = _mkplan (xtype, (ydims... ,), region)
312
+ pp = create_plan (xtype, (ydims... ,), region)
299
313
300
- rCuFFTPlan {T,K,inplace,N} (pp, X, (ydims... ,), region, xtype)
314
+ rCuFFTPlan {T,K,inplace,N} (pp... , X, (ydims... ,), region, xtype)
301
315
end
302
316
303
317
# FIXME : plan_inv methods allocate needlessly (to provide type parameters)
304
318
# Perhaps use FakeArray types to avoid this.
305
319
306
320
function plan_inv (p:: cCuFFTPlan{T,CUFFT_FORWARD,inplace,N} ) where {T,N,inplace}
307
321
X = CuArray {T} (undef, p. sz)
308
- pp = _mkplan (p. xtype, p. sz, p. region)
309
- ScaledPlan (cCuFFTPlan {T,CUFFT_INVERSE,inplace,N} (pp, X, p. sz, p. region,
322
+ pp = create_plan (p. xtype, p. sz, p. region)
323
+ ScaledPlan (cCuFFTPlan {T,CUFFT_INVERSE,inplace,N} (pp... , X, p. sz, p. region,
310
324
p. xtype),
311
325
normalization (X, p. region))
312
326
end
313
327
314
328
function plan_inv (p:: cCuFFTPlan{T,CUFFT_INVERSE,inplace,N} ) where {T,N,inplace}
315
329
X = CuArray {T} (undef, p. sz)
316
- pp = _mkplan (p. xtype, p. sz, p. region)
317
- ScaledPlan (cCuFFTPlan {T,CUFFT_FORWARD,inplace,N} (pp, X, p. sz, p. region,
330
+ pp = create_plan (p. xtype, p. sz, p. region)
331
+ ScaledPlan (cCuFFTPlan {T,CUFFT_FORWARD,inplace,N} (pp... , X, p. sz, p. region,
318
332
p. xtype),
319
333
normalization (X, p. region))
320
334
end
@@ -324,9 +338,8 @@ function plan_inv(p::rCuFFTPlan{T,CUFFT_INVERSE,inplace,N}
324
338
X = CuArray {real(T)} (undef, p. osz)
325
339
Y = CuArray {T} (undef, p. sz)
326
340
xtype = p. xtype == CUFFT_C2R ? CUFFT_R2C : CUFFT_D2Z
327
- pp = _mkplan (xtype, p. osz, p. region)
328
- ScaledPlan (rCuFFTPlan {real(T),CUFFT_FORWARD,inplace,N} (pp, X, p. sz, p. region,
329
- xtype),
341
+ pp = create_plan (xtype, p. osz, p. region)
342
+ ScaledPlan (rCuFFTPlan {real(T),CUFFT_FORWARD,inplace,N} (pp... , X, p. sz, p. region, xtype),
330
343
normalization (X, p. region))
331
344
end
332
345
@@ -335,8 +348,8 @@ function plan_inv(p::rCuFFTPlan{T,CUFFT_FORWARD,inplace,N}
335
348
X = CuArray {complex(T)} (undef, p. osz)
336
349
Y = CuArray {T} (undef, p. sz)
337
350
xtype = p. xtype == CUFFT_R2C ? CUFFT_C2R : CUFFT_Z2D
338
- pp = _mkplan (xtype, p. sz, p. region)
339
- ScaledPlan (rCuFFTPlan {complex(T),CUFFT_INVERSE,inplace,N} (pp, X, p. sz,
351
+ pp = create_plan (xtype, p. sz, p. region)
352
+ ScaledPlan (rCuFFTPlan {complex(T),CUFFT_INVERSE,inplace,N} (pp... , X, p. sz,
340
353
p. region, xtype),
341
354
normalization (Y, p. region))
342
355
end
0 commit comments