@@ -136,47 +136,63 @@ for (f, xtype, inplace, forward) in (
136
136
(:plan_fft , :rocfft_transform_type_complex_forward , :false , :true ),
137
137
(:plan_bfft , :rocfft_transform_type_complex_inverse , :false , :false ),
138
138
)
139
- @eval function $f (X:: ROCArray{T, N} , region) where {T <: rocfftComplexes , N}
140
- _inplace = $ (inplace)
141
- _xtype = $ (xtype)
142
- R = length (region)
143
- region = NTuple {R,Int} (region)
144
- pp = get_plan (_xtype, size (X), T, _inplace, region)
145
- return cROCFFTPlan {T,$forward,_inplace,N,R,Nothing} (pp... , X, size (X), _xtype, region, nothing , false , T)
139
+ @eval begin
140
+ # Try to constant-propagate the `region` argument so that its length `R` can be inferred.
141
+ Base. @constprop :aggressive function $f (X:: ROCArray{T, N} , region) where {T <: rocfftComplexes , N}
142
+ R = length (region)
143
+ region = NTuple {R,Int} (region)
144
+ return $ f (X, region)
145
+ end
146
+
147
+ function $f (X:: ROCArray{T, N} , region:: NTuple{R,Int} ) where {T <: rocfftComplexes , N, R}
148
+ _inplace = $ (inplace)
149
+ _xtype = $ (xtype)
150
+ pp = get_plan (_xtype, size (X), T, _inplace, region)
151
+ return cROCFFTPlan {T,$forward,_inplace,N,R,Nothing} (pp... , X, size (X), _xtype, region, nothing , false , T)
152
+ end
146
153
end
147
154
end
148
155
149
- function plan_rfft (X:: ROCArray{T,N} , region) where {T<: rocfftReals ,N}
150
- inplace = false
151
- xtype = rocfft_transform_type_real_forward
156
+ Base. @constprop :aggressive function plan_rfft (X:: ROCArray{T,N} , region) where {T<: rocfftReals ,N}
152
157
R = length (region)
153
158
region = NTuple {R,Int} (region)
159
+ return plan_rfft (X, region)
160
+ end
161
+
162
+ function plan_rfft (X:: ROCArray{T,N} , region:: NTuple{R,Int} ) where {T<: rocfftReals ,N,R}
163
+ inplace = false
164
+ xtype = rocfft_transform_type_real_forward
154
165
pp = get_plan (xtype, size (X), T, inplace, region)
155
- ydims = collect (size (X))
156
- ydims[region[1 ]] = div (ydims[region[1 ]],2 ) + 1
166
+
167
+ xdims = size (X)
168
+ ydims = Base. setindex (xdims, div (xdims[region[1 ]],2 ) + 1 , region[1 ])
157
169
158
170
# The buffer is not needed for real-to-complex (`mul!`),
159
171
# but it’s required for complex-to-real (`ldiv!`).
160
- buffer = ROCArray {complex(T)} (undef, ydims... )
172
+ buffer = ROCArray {complex(T)} (undef, ydims)
161
173
B = typeof (buffer)
162
174
163
- return rROCFFTPlan {T,ROCFFT_FORWARD,inplace,N,R,B} (pp... , X, ( ydims... ,) , xtype, region, buffer, true , T)
175
+ return rROCFFTPlan {T,ROCFFT_FORWARD,inplace,N,R,B} (pp... , X, ydims, xtype, region, buffer, true , T)
164
176
end
165
177
166
- function plan_brfft (X:: ROCArray{T,N} , d:: Integer , region) where {T <: rocfftComplexes , N}
167
- inplace = false
168
- xtype = rocfft_transform_type_real_inverse
178
+ Base. @constprop :aggressive function plan_brfft (X:: ROCArray{T,N} , d:: Integer , region) where {T <: rocfftComplexes , N}
169
179
R = length (region)
170
180
region = NTuple {R,Int} (region)
171
- ydims = collect (size (X))
172
- ydims[region[1 ]] = d
173
- pp = get_plan (xtype, (ydims... ,), T, inplace, region)
181
+ return plan_brfft (X, d, region)
182
+ end
183
+
184
+ function plan_brfft (X:: ROCArray{T,N} , d:: Integer , region:: NTuple{R,Int} ) where {T <: rocfftComplexes , N, R}
185
+ inplace = false
186
+ xtype = rocfft_transform_type_real_inverse
187
+ xdims = size (X)
188
+ ydims = Base. setindex (xdims, d, region[1 ])
189
+ pp = get_plan (xtype, ydims, T, inplace, region)
174
190
175
191
# Buffer to not modify the input in a complex-to-real FFT.
176
192
buffer = ROCArray {T} (undef, size (X))
177
193
B = typeof (buffer)
178
194
179
- return rROCFFTPlan {T,ROCFFT_INVERSE,inplace,N,R,B} (pp... , X, ( ydims... ,) , xtype, region, buffer, false , T)
195
+ return rROCFFTPlan {T,ROCFFT_INVERSE,inplace,N,R,B} (pp... , X, ydims, xtype, region, buffer, false , T)
180
196
end
181
197
182
198
# FIXME : plan_inv methods allocate needlessly (to provide type parameters and normalization function)
0 commit comments