@@ -25,32 +25,34 @@ Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale)
25
25
26
26
# N is the number of dimensions
27
27
28
- mutable struct CuFFTPlan{T<: cufftNumber ,S<: cufftNumber ,K,inplace,N} <: Plan{S}
28
+ mutable struct CuFFTPlan{T<: cufftNumber ,S<: cufftNumber ,K,inplace,N,R,B } <: Plan{S}
29
29
# handle to Cuda low level plan. Note that this plan sometimes has lower dimensions
30
30
# to handle more transform cases such as individual directions
31
31
handle:: cufftHandle
32
32
ctx:: CuContext
33
33
stream:: CuStream
34
34
input_size:: NTuple{N,Int} # Julia size of input array
35
35
output_size:: NTuple{N,Int} # Julia size of output array
36
- region:: Any
36
+ region:: NTuple{R,Int}
37
+ buffer:: B # buffer for out-of-place complex-to-real FFT, or `nothing` if not needed
37
38
pinv:: ScaledPlan{T} # required by AbstractFFTs API, will be defined by AbstractFFTs if needed
38
39
39
- function CuFFTPlan {T,S,K,inplace,N} (handle:: cufftHandle ,
40
- input_size:: NTuple{N,Int} , output_size:: NTuple{N,Int} , region
41
- ) where {T<: cufftNumber ,S<: cufftNumber ,K,inplace,N}
40
+ function CuFFTPlan {T,S,K,inplace,N,R,B} (handle:: cufftHandle ,
41
+ input_size:: NTuple{N,Int} , output_size:: NTuple{N,Int} ,
42
+ region:: NTuple{R,Int} , buffer:: B
43
+ ) where {T<: cufftNumber ,S<: cufftNumber ,K,inplace,N,R,B}
42
44
abs (K) == 1 || throw (ArgumentError (" FFT direction must be either -1 (forward) or +1 (inverse)" ))
43
45
inplace isa Bool || throw (ArgumentError (" FFT inplace argument must be a Bool" ))
44
- p = new {T,S,K,inplace,N} (handle, context (), stream (), input_size, output_size, region)
46
+ p = new {T,S,K,inplace,N,R,B } (handle, context (), stream (), input_size, output_size, region, buffer )
45
47
finalizer (unsafe_free!, p)
46
48
p
47
49
end
48
50
end
49
51
50
- function CuFFTPlan {T,S,K,inplace,N} (handle:: cufftHandle , X:: DenseCuArray{S,N} ,
51
- sizey:: NTuple{N,Int} , region,
52
- ) where {T<: cufftNumber ,S<: cufftNumber ,K,inplace,N}
53
- CuFFTPlan {T,S,K,inplace,N} (handle, size (X), sizey, region)
52
+ function CuFFTPlan {T,S,K,inplace,N,R,B } (handle:: cufftHandle , X:: DenseCuArray{S,N} ,
53
+ sizey:: NTuple{N,Int} , region:: NTuple{R,Int} , buffer :: B
54
+ ) where {T<: cufftNumber ,S<: cufftNumber ,K,inplace,N,R,B }
55
+ CuFFTPlan {T,S,K,inplace,N,R,B } (handle, size (X), sizey, region, buffer )
54
56
end
55
57
56
58
function CUDA. unsafe_free! (plan:: CuFFTPlan )
@@ -60,6 +62,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan)
60
62
end
61
63
plan. handle = C_NULL
62
64
end
65
+ if ! isnothing (plan. buffer)
66
+ CUDA. unsafe_free! (plan. buffer)
67
+ end
63
68
end
64
69
65
70
function showfftdims (io, sz, T)
@@ -151,103 +156,116 @@ end
151
156
function plan_fft! (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
152
157
K = CUFFT_FORWARD
153
158
inplace = true
154
- region = Tuple (region)
159
+ R = length (region)
160
+ region = NTuple {R,Int} (region)
155
161
156
162
md = plan_max_dims (region, size (X))
157
163
sizex = size (X)[1 : md]
158
164
handle = cufftGetPlan (T, T, sizex, region)
159
165
160
- CuFFTPlan {T,T,K,inplace,N} (handle, X, size (X), region)
166
+ CuFFTPlan {T,T,K,inplace,N,R,Nothing } (handle, X, size (X), region, nothing )
161
167
end
162
168
163
-
164
169
function plan_bfft! (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
165
170
K = CUFFT_INVERSE
166
171
inplace = true
167
- region = Tuple (region)
172
+ R = length (region)
173
+ region = NTuple {R,Int} (region)
168
174
169
175
md = plan_max_dims (region, size (X))
170
176
sizex = size (X)[1 : md]
171
177
handle = cufftGetPlan (T, T, sizex, region)
172
178
173
- CuFFTPlan {T,T,K,inplace,N} (handle, X, size (X), region)
179
+ CuFFTPlan {T,T,K,inplace,N,R,Nothing } (handle, X, size (X), region, nothing )
174
180
end
175
181
176
182
# out-of-place complex
177
183
function plan_fft (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
178
184
K = CUFFT_FORWARD
179
185
inplace = false
180
- region = Tuple (region)
186
+ R = length (region)
187
+ region = NTuple {R,Int} (region)
181
188
182
189
md = plan_max_dims (region,size (X))
183
190
sizex = size (X)[1 : md]
184
191
handle = cufftGetPlan (T, T, sizex, region)
185
192
186
- CuFFTPlan {T,T,K,inplace,N} (handle, X, size (X), region)
193
+ CuFFTPlan {T,T,K,inplace,N,R,Nothing } (handle, X, size (X), region, nothing )
187
194
end
188
195
189
196
function plan_bfft (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
190
197
K = CUFFT_INVERSE
191
198
inplace = false
192
- region = Tuple (region)
199
+ R = length (region)
200
+ region = NTuple {R,Int} (region)
193
201
194
202
md = plan_max_dims (region,size (X))
195
203
sizex = size (X)[1 : md]
196
204
handle = cufftGetPlan (T, T, sizex, region)
197
205
198
- CuFFTPlan {T,T,K,inplace,N} (handle, size (X), size (X), region)
206
+ CuFFTPlan {T,T,K,inplace,N,R,Nothing } (handle, size (X), size (X), region, nothing )
199
207
end
200
208
201
209
# out-of-place real-to-complex
202
210
function plan_rfft (X:: DenseCuArray{T,N} , region) where {T<: cufftReals ,N}
203
211
K = CUFFT_FORWARD
204
212
inplace = false
205
- region = Tuple (region)
213
+ R = length (region)
214
+ region = NTuple {R,Int} (region)
206
215
207
216
md = plan_max_dims (region,size (X))
208
- # X = front_view(X, md)
209
217
sizex = size (X)[1 : md]
210
218
211
219
handle = cufftGetPlan (complex (T), T, sizex, region)
212
220
213
221
ydims = collect (size (X))
214
- ydims[region[1 ]] = div (ydims[region[1 ]],2 ) + 1
222
+ ydims[region[1 ]] = div (ydims[region[1 ]], 2 ) + 1
215
223
216
- CuFFTPlan {complex(T),T,K,inplace,N} (handle, size (X), (ydims... ,), region)
224
+ # The buffer is not needed for real-to-complex (`mul!`),
225
+ # but it’s required for complex-to-real (`ldiv!`).
226
+ buffer = CuArray {complex(T)} (undef, ydims... )
227
+ B = typeof (buffer)
228
+
229
+ CuFFTPlan {complex(T),T,K,inplace,N,R,B} (handle, size (X), (ydims... ,), region, buffer)
217
230
end
218
231
219
- function plan_brfft (X:: DenseCuArray{T,N} , d:: Integer , region:: Any ) where {T<: cufftComplexes ,N}
232
+ # out-of-place complex-to-real
233
+ function plan_brfft (X:: DenseCuArray{T,N} , d:: Integer , region) where {T<: cufftComplexes ,N}
220
234
K = CUFFT_INVERSE
221
235
inplace = false
222
- region = Tuple (region)
236
+ R = length (region)
237
+ region = NTuple {R,Int} (region)
223
238
224
239
ydims = collect (size (X))
225
240
ydims[region[1 ]] = d
226
241
227
242
handle = cufftGetPlan (real (T), T, (ydims... ,), region)
228
243
229
- CuFFTPlan {real(T),T,K,inplace,N} (handle, size (X), (ydims... ,), region)
244
+ buffer = CuArray {T} (undef, size (X))
245
+ B = typeof (buffer)
246
+
247
+ CuFFTPlan {real(T),T,K,inplace,N,R,B} (handle, size (X), (ydims... ,), region, buffer)
230
248
end
231
249
232
250
233
251
# FIXME : plan_inv methods allocate needlessly (to provide type parameters)
234
252
# Perhaps use FakeArray types to avoid this.
235
253
236
- function plan_inv (p:: CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N}
237
- ) where {T<: cufftNumber ,S<: cufftNumber ,N,inplace }
254
+ function plan_inv (p:: CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N,R,B }
255
+ ) where {T<: cufftNumber ,S<: cufftNumber ,inplace,N,R,B }
238
256
md_osz = plan_max_dims (p. region, p. output_size)
239
257
sz_X = p. output_size[1 : md_osz]
240
258
handle = cufftGetPlan (S, T, sz_X, p. region)
241
- ScaledPlan (CuFFTPlan {S,T,CUFFT_FORWARD,inplace,N} (handle, p. output_size, p. input_size, p. region),
259
+ ScaledPlan (CuFFTPlan {S,T,CUFFT_FORWARD,inplace,N,R,B } (handle, p. output_size, p. input_size, p. region, p . buffer ),
242
260
normalization (real (T), p. output_size, p. region))
243
261
end
244
262
245
- function plan_inv (p:: CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N}
246
- ) where {T<: cufftNumber ,S<: cufftNumber ,N,inplace }
263
+ function plan_inv (p:: CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N,R,B }
264
+ ) where {T<: cufftNumber ,S<: cufftNumber ,inplace,N,R,B }
247
265
md_isz = plan_max_dims (p. region, p. input_size)
248
266
sz_Y = p. input_size[1 : md_isz]
249
267
handle = cufftGetPlan (S, T, sz_Y, p. region)
250
- ScaledPlan (CuFFTPlan {S,T,CUFFT_INVERSE,inplace,N} (handle, p. output_size, p. input_size, p. region),
268
+ ScaledPlan (CuFFTPlan {S,T,CUFFT_INVERSE,inplace,N,R,B } (handle, p. output_size, p. input_size, p. region, p . buffer ),
251
269
normalization (real (S), p. input_size, p. region))
252
270
end
253
271
@@ -309,10 +327,14 @@ function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T,S,K,inplace}, x::
309
327
) where {T,S,K,inplace}
310
328
assert_applicable (p, x, y)
311
329
if ! inplace && T<: Real
312
- # Out-of-place complex-to-real FFT will always overwrite input buffer.
313
- x = copy (x)
330
+ # Out-of-place complex-to-real FFT will always overwrite input x.
331
+ # We copy the input x in an auxiliary buffer.
332
+ z = p. buffer
333
+ copyto! (z, x)
334
+ else
335
+ z = x
314
336
end
315
- unsafe_execute_trailing! (p, x , y)
337
+ unsafe_execute_trailing! (p, z , y)
316
338
y
317
339
end
318
340
@@ -323,13 +345,21 @@ function Base.:(*)(p::CuFFTPlan{T,S,K,true}, x::DenseCuArray{S}) where {T,S,K}
323
345
end
324
346
325
347
function Base.:(* )(p:: CuFFTPlan{T,S,K,false} , x:: DenseCuArray{S1,M} ) where {T,S,K,S1,M}
326
- if S1 != S || T<: Real
327
- # Convert to the expected input type. Also,
328
- # Out-of-place complex-to-real FFT will always overwrite input buffer.
329
- x = copy1 (S, x)
348
+ if T<: Real
349
+ # Out-of-place complex-to-real FFT will always overwrite input x.
350
+ # We copy the input x in an auxiliary buffer.
351
+ z = p. buffer
352
+ copyto! (z, x)
353
+ else
354
+ if S1 != S
355
+ # Convert to the expected input type.
356
+ z = copy1 (S, x)
357
+ else
358
+ z = x
359
+ end
330
360
end
331
- assert_applicable (p, x )
361
+ assert_applicable (p, z )
332
362
y = CuArray {T,M} (undef, p. output_size)
333
- unsafe_execute_trailing! (p, x , y)
363
+ unsafe_execute_trailing! (p, z , y)
334
364
y
335
365
end
0 commit comments