Skip to content

Commit bb47d68

Browse files
authored
Fix inference of FFT plan creation (#739)
* Make FFT plan construction inferrable * Test inference
1 parent 29aaa16 commit bb47d68

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

src/fft/fft.jl

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -136,47 +136,63 @@ for (f, xtype, inplace, forward) in (
136136
(:plan_fft, :rocfft_transform_type_complex_forward, :false, :true),
137137
(:plan_bfft, :rocfft_transform_type_complex_inverse, :false, :false),
138138
)
139-
@eval function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N}
140-
_inplace = $(inplace)
141-
_xtype = $(xtype)
142-
R = length(region)
143-
region = NTuple{R,Int}(region)
144-
pp = get_plan(_xtype, size(X), T, _inplace, region)
145-
return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T)
139+
@eval begin
140+
# Try to constant-propagate the `region` argument so that its length `R` can be inferred.
141+
Base.@constprop :aggressive function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N}
142+
R = length(region)
143+
region = NTuple{R,Int}(region)
144+
return $f(X, region)
145+
end
146+
147+
function $f(X::ROCArray{T, N}, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R}
148+
_inplace = $(inplace)
149+
_xtype = $(xtype)
150+
pp = get_plan(_xtype, size(X), T, _inplace, region)
151+
return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T)
152+
end
146153
end
147154
end
148155

149-
function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N}
150-
inplace = false
151-
xtype = rocfft_transform_type_real_forward
156+
Base.@constprop :aggressive function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N}
152157
R = length(region)
153158
region = NTuple{R,Int}(region)
159+
return plan_rfft(X, region)
160+
end
161+
162+
function plan_rfft(X::ROCArray{T,N}, region::NTuple{R,Int}) where {T<:rocfftReals,N,R}
163+
inplace = false
164+
xtype = rocfft_transform_type_real_forward
154165
pp = get_plan(xtype, size(X), T, inplace, region)
155-
ydims = collect(size(X))
156-
ydims[region[1]] = div(ydims[region[1]],2) + 1
166+
167+
xdims = size(X)
168+
ydims = Base.setindex(xdims, div(xdims[region[1]],2) + 1, region[1])
157169

158170
# The buffer is not needed for real-to-complex (`mul!`),
159171
# but it’s required for complex-to-real (`ldiv!`).
160-
buffer = ROCArray{complex(T)}(undef, ydims...)
172+
buffer = ROCArray{complex(T)}(undef, ydims)
161173
B = typeof(buffer)
162174

163-
return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, true, T)
175+
return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, true, T)
164176
end
165177

166-
function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N}
167-
inplace = false
168-
xtype = rocfft_transform_type_real_inverse
178+
Base.@constprop :aggressive function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N}
169179
R = length(region)
170180
region = NTuple{R,Int}(region)
171-
ydims = collect(size(X))
172-
ydims[region[1]] = d
173-
pp = get_plan(xtype, (ydims...,), T, inplace, region)
181+
return plan_brfft(X, d, region)
182+
end
183+
184+
function plan_brfft(X::ROCArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R}
185+
inplace = false
186+
xtype = rocfft_transform_type_real_inverse
187+
xdims = size(X)
188+
ydims = Base.setindex(xdims, d, region[1])
189+
pp = get_plan(xtype, ydims, T, inplace, region)
174190

175191
# Buffer to not modify the input in a complex-to-real FFT.
176192
buffer = ROCArray{T}(undef, size(X))
177193
B = typeof(buffer)
178194

179-
return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, false, T)
195+
return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, false, T)
180196
end
181197

182198
# FIXME: plan_inv methods allocate needlessly (to provide type parameters and normalization function)

test/rocarray/fft.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
1717
fftw_X = fft(X)
1818

1919
dX = ROCArray(X)
20-
p = plan_fft(dX)
20+
p = @inferred plan_fft(dX)
2121
dY = p * dX
2222
@test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL)
2323
@test X collect(dX)
@@ -37,7 +37,7 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
3737
fftw_X = fft(X)
3838

3939
dX = ROCArray(X)
40-
p = plan_fft!(dX)
40+
p = @inferred plan_fft!(dX)
4141
p * dX
4242
@test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL)
4343

@@ -50,7 +50,7 @@ function batched(X::AbstractArray{T,N}, region) where {T <: Complex,N}
5050
fftw_X = fft(X, region)
5151

5252
dX = ROCArray(X)
53-
p = plan_fft!(dX, region)
53+
p = @inferred plan_fft!(dX, region)
5454
p * dX
5555
@test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL)
5656

@@ -173,7 +173,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
173173
fftw_X = rfft(X)
174174
dX = ROCArray(X)
175175

176-
p = plan_rfft(dX)
176+
p = @inferred plan_rfft(dX)
177177
dY = p * dX
178178
Y = collect(dY)
179179
@test isapprox(Y, fftw_X; rtol=MYRTOL, atol=MYATOL)
@@ -197,7 +197,7 @@ function batched(X::AbstractArray{T,N},region) where {T <: Real,N}
197197
fftw_X = rfft(X,region)
198198
dX = ROCArray(X)
199199

200-
p = plan_rfft(dX, region)
200+
p = @inferred plan_rfft(dX, region)
201201
dY = p * dX
202202
@test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL)
203203
@test X collect(dX)

0 commit comments

Comments
 (0)