@@ -152,12 +152,23 @@ end
152
152
# region is an iterable subset of dimensions
153
153
# spec. an integer, range, tuple, or array
154
154
155
+ # try to constant-propagate the `region` argument when it is not a tuple. This helps with
156
+ # inference of calls like plan_fft(X), which is translated by AbstractFFTs.jl into
157
+ # plan_fft(X, 1:ndims(X)).
158
+ for f in (:plan_fft! , :plan_bfft! , :plan_fft , :plan_bfft )
159
+ @eval begin
160
+ Base. @constprop :aggressive function $f (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
161
+ R = length (region)
162
+ region = NTuple {R,Int} (region)
163
+ $ f (X, region)
164
+ end
165
+ end
166
+ end
167
+
155
168
# inplace complex
156
- function plan_fft! (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
169
+ function plan_fft! (X:: DenseCuArray{T,N} , region:: NTuple{R,Int} ) where {T<: cufftComplexes ,N,R }
157
170
K = CUFFT_FORWARD
158
171
inplace = true
159
- R = length (region)
160
- region = NTuple {R,Int} (region)
161
172
162
173
md = plan_max_dims (region, size (X))
163
174
sizex = size (X)[1 : md]
@@ -166,11 +177,9 @@ function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
166
177
CuFFTPlan {T,T,K,inplace,N,R,Nothing} (handle, X, size (X), region, nothing )
167
178
end
168
179
169
- function plan_bfft! (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
180
+ function plan_bfft! (X:: DenseCuArray{T,N} , region:: NTuple{R,Int} ) where {T<: cufftComplexes ,N,R }
170
181
K = CUFFT_INVERSE
171
182
inplace = true
172
- R = length (region)
173
- region = NTuple {R,Int} (region)
174
183
175
184
md = plan_max_dims (region, size (X))
176
185
sizex = size (X)[1 : md]
@@ -180,11 +189,9 @@ function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
180
189
end
181
190
182
191
# out-of-place complex
183
- function plan_fft (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
192
+ function plan_fft (X:: DenseCuArray{T,N} , region:: NTuple{R,Int} ) where {T<: cufftComplexes ,N,R }
184
193
K = CUFFT_FORWARD
185
194
inplace = false
186
- R = length (region)
187
- region = NTuple {R,Int} (region)
188
195
189
196
md = plan_max_dims (region,size (X))
190
197
sizex = size (X)[1 : md]
@@ -193,11 +200,9 @@ function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
193
200
CuFFTPlan {T,T,K,inplace,N,R,Nothing} (handle, X, size (X), region, nothing )
194
201
end
195
202
196
- function plan_bfft (X:: DenseCuArray{T,N} , region) where {T<: cufftComplexes ,N}
203
+ function plan_bfft (X:: DenseCuArray{T,N} , region:: NTuple{R,Int} ) where {T<: cufftComplexes ,N,R }
197
204
K = CUFFT_INVERSE
198
205
inplace = false
199
- R = length (region)
200
- region = NTuple {R,Int} (region)
201
206
202
207
md = plan_max_dims (region,size (X))
203
208
sizex = size (X)[1 : md]
@@ -207,19 +212,23 @@ function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
207
212
end
208
213
209
214
# out-of-place real-to-complex
210
- function plan_rfft (X:: DenseCuArray{T,N} , region) where {T<: cufftReals ,N}
211
- K = CUFFT_FORWARD
212
- inplace = false
215
+ Base. @constprop :aggressive function plan_rfft (X:: DenseCuArray{T,N} , region) where {T<: cufftReals ,N}
213
216
R = length (region)
214
217
region = NTuple {R,Int} (region)
218
+ plan_rfft (X, region)
219
+ end
220
+
221
+ function plan_rfft (X:: DenseCuArray{T,N} , region:: NTuple{R,Int} ) where {T<: cufftReals ,N,R}
222
+ K = CUFFT_FORWARD
223
+ inplace = false
215
224
216
225
md = plan_max_dims (region,size (X))
217
226
sizex = size (X)[1 : md]
218
227
219
228
handle = cufftGetPlan (complex (T), T, sizex, region)
220
229
221
- ydims = collect ( size (X) )
222
- ydims[region[ 1 ]] = div (ydims [region[1 ]], 2 ) + 1
230
+ xdims = size (X)
231
+ ydims = Base . setindex (xdims, div (xdims [region[1 ]], 2 ) + 1 , region[ 1 ])
223
232
224
233
# The buffer is not needed for real-to-complex (`mul!`),
225
234
# but it’s required for complex-to-real (`ldiv!`).
@@ -230,21 +239,24 @@ function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
230
239
end
231
240
232
241
# out-of-place complex-to-real
233
- function plan_brfft (X:: DenseCuArray{T,N} , d:: Integer , region) where {T<: cufftComplexes ,N}
234
- K = CUFFT_INVERSE
235
- inplace = false
242
+ Base. @constprop :aggressive function plan_brfft (X:: DenseCuArray{T,N} , d:: Integer , region) where {T<: cufftComplexes ,N}
236
243
R = length (region)
237
244
region = NTuple {R,Int} (region)
245
+ plan_brfft (X, d, region)
246
+ end
238
247
239
- ydims = collect (size (X))
240
- ydims[region[1 ]] = d
248
+ function plan_brfft (X:: DenseCuArray{T,N} , d:: Integer , region:: NTuple{R,Int} ) where {T<: cufftComplexes ,N,R}
249
+ K = CUFFT_INVERSE
250
+ inplace = false
241
251
242
- handle = cufftGetPlan (real (T), T, (ydims... ,), region)
252
+ xdims = size (X)
253
+ ydims = Base. setindex (xdims, d, region[1 ])
254
+ handle = cufftGetPlan (real (T), T, ydims, region)
243
255
244
256
buffer = CuArray {T} (undef, size (X))
245
257
B = typeof (buffer)
246
258
247
- CuFFTPlan {real(T),T,K,inplace,N,R,B} (handle, size (X), ( ydims... ,) , region, buffer)
259
+ CuFFTPlan {real(T),T,K,inplace,N,R,B} (handle, size (X), ydims, region, buffer)
248
260
end
249
261
250
262
0 commit comments