Skip to content

Commit 4ee65df

Browse files
committed
Correct inverse plan computation and caching in test plans
1 parent 3e7d412 commit 4ee65df

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

test/testplans.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,20 @@ end
2727
function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T}
2828
return InverseTestPlan{T}(region, size(x))
2929
end
30+
3031
function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T}
3132
unscaled_pinv = InverseTestPlan{T}(p.region, p.sz)
32-
unscaled_pinv.pinv = p
33-
pinv = AbstractFFTs.ScaledPlan(
34-
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
35-
)
33+
N = AbstractFFTs.normalization(T, p.sz, p.region)
34+
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N)
35+
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N)
3636
return pinv
3737
end
38-
function AbstractFFTs.plan_inv(p::InverseTestPlan{T}) where {T}
39-
unscaled_pinv = TestPlan{T}(p.region, p.sz)
40-
unscaled_pinv.pinv = p
41-
pinv = AbstractFFTs.ScaledPlan(
42-
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
43-
)
44-
return pinv
38+
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T}
39+
unscaled_p = TestPlan{T}(pinv.region, pinv.sz)
40+
N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
41+
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N)
42+
p = AbstractFFTs.ScaledPlan(unscaled_p, N)
43+
return p
4544
end
4645

4746
# Just a helper function since forward and backward are nearly identical
@@ -118,22 +117,23 @@ function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
118117
firstdim = first(p.region)::Int
119118
d = p.sz[firstdim]
120119
sz = ntuple(i -> i == firstdim ? d ÷ 2 + 1 : p.sz[i], Val(N))
120+
_N = AbstractFFTs.normalization(T, p.sz, p.region)
121+
121122
unscaled_pinv = InverseTestRPlan{T}(d, p.region, sz)
122-
unscaled_pinv.pinv = p
123-
pinv = AbstractFFTs.ScaledPlan(
124-
unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region),
125-
)
123+
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N)
124+
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N)
126125
return pinv
127126
end
128-
function AbstractFFTs.plan_inv(p::InverseTestRPlan{T,N}) where {T,N}
129-
firstdim = first(p.region)::Int
130-
sz = ntuple(i -> i == firstdim ? p.d : p.sz[i], Val(N))
131-
unscaled_pinv = TestRPlan{T}(p.region, sz)
132-
unscaled_pinv.pinv = p
133-
pinv = AbstractFFTs.ScaledPlan(
134-
unscaled_pinv, AbstractFFTs.normalization(T, sz, p.region),
135-
)
136-
return pinv
127+
128+
function AbstractFFTs.plan_inv(pinv::InverseTestRPlan{T,N}) where {T,N}
129+
firstdim = first(pinv.region)::Int
130+
sz = ntuple(i -> i == firstdim ? pinv.d : pinv.sz[i], Val(N))
131+
_N = AbstractFFTs.normalization(T, sz, pinv.region)
132+
133+
unscaled_p = TestRPlan{T}(pinv.region, sz)
134+
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N)
135+
p = AbstractFFTs.ScaledPlan(unscaled_p, _N)
136+
return p
137137
end
138138

139139
Base.size(p::TestRPlan) = p.sz

0 commit comments

Comments
 (0)