Skip to content

Commit 8ffa7df

Browse files
committed
tests pass
1 parent 1592531 commit 8ffa7df

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

ext/AbstractFFTsForwardDiffExt.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using AbstractFFTs
44
using AbstractFFTs.LinearAlgebra
55
import ForwardDiff
66
import ForwardDiff: Dual
7-
import AbstractFFTs: Plan, mul!
7+
import AbstractFFTs: Plan, mul!, dualplan, dual2array
88

99

1010
AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x)
@@ -32,6 +32,7 @@ end
3232

3333
DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{T}) where {Tag,T<:Real,V,N} = DualPlan{Dual{Tag,T,N},typeof(p)}(p)
3434
DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{Complex{T}}) where {Tag,T<:Real,V,N} = DualPlan{Complex{Dual{Tag,T,N}},typeof(p)}(p)
35+
dualplan(D, p) = DualPlan(D, p)
3536
Base.size(p::DualPlan) = Base.tail(size(p.p))
3637
Base.:*(p::DualPlan{DT}, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p.p * dual2array(x))
3738
Base.:*(p::DualPlan{Complex{DT}}, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p.p * dual2array(x))
@@ -48,11 +49,18 @@ end
4849

4950
for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft)
5051
@eval begin
51-
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
52-
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
52+
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
53+
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
5354
end
5455
end
5556

5657

58+
for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex?
59+
@eval begin
60+
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
61+
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, mdims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims))
62+
end
63+
end
64+
5765

5866
end # module

src/AbstractFFTs.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
88
include("definitions.jl")
99
include("TestUtils.jl")
1010

11+
# Create function used by multiple extension as loading order is not guaranteed
12+
function dualplan end
13+
function dual2array end
14+
1115
if !isdefined(Base, :get_extension)
1216
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
1317
include("../ext/AbstractFFTsTestExt.jl")
1418
include("../ext/AbstractFFTsForwardDiffExt.jl")
1519
end
1620

21+
1722
end # module

0 commit comments

Comments
 (0)