-
Notifications
You must be signed in to change notification settings - Fork 35
Add ForwardDiff extension #138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dlfivefifty
wants to merge
19
commits into
master
Choose a base branch
from
dl/ForwardDiffExt
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+148
−10
Open
Changes from 9 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
2b60a83
Add ForwardDiff extension
dlfivefifty 8a88755
add tests
dlfivefifty 4977633
add tests
dlfivefifty 1695c7e
Add plan_mul to capture partial interface implementation
dlfivefifty 6028c0e
Add DualPlan
dlfivefifty cc6e3d8
revert definitions
dlfivefifty 1592531
Revert TestPlans
dlfivefifty 8ffa7df
tests pass
dlfivefifty 4943167
Update AbstractFFTsForwardDiffExt.jl
dlfivefifty 35f794b
Update Project.toml
dlfivefifty a516bd0
Overload only _fftfloat
dlfivefifty 1f5c6a3
Merge branch 'dl/ForwardDiffExt' of https://github.com/JuliaMath/Abst…
dlfivefifty a968cd1
Only load/test ForwardDiff on versions that support extensions
dlfivefifty 76b776e
Update abstractfftsforwarddiff.jl
dlfivefifty 81168f3
Update src/AbstractFFTs.jl
dlfivefifty cde9caa
add complex tests
dlfivefifty aff09fa
Merge branch 'dl/ForwardDiffExt' of https://github.com/JuliaMath/Abst…
dlfivefifty 2ea43f3
Generalise dual2array/array2dual for strided
dlfivefifty 4fe2464
Update abstractfftsforwarddiff.jl
dlfivefifty File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
module AbstractFFTsForwardDiffExt | ||
|
||
using AbstractFFTs | ||
using AbstractFFTs.LinearAlgebra | ||
import ForwardDiff | ||
import ForwardDiff: Dual | ||
import AbstractFFTs: Plan, mul!, dualplan, dual2array | ||
|
||
|
||
AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) | ||
AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im | ||
|
||
AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x) | ||
AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) | ||
dlfivefifty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) | ||
dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) | ||
array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x)) | ||
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) | ||
|
||
|
||
######## | ||
# DualPlan | ||
# represents a plan acting on dual numbers. We wrap a plan acting on a higher dimensional tensor | ||
# as an array of duals can be reinterpreted as a higher dimensional array. | ||
# This allows standard FFTW plans to act on arrays of duals. | ||
##### | ||
struct DualPlan{T,P} <: Plan{T} | ||
p::P | ||
DualPlan{T,P}(p) where {T,P} = new(p) | ||
end | ||
|
||
DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{T}) where {Tag,T<:Real,V,N} = DualPlan{Dual{Tag,T,N},typeof(p)}(p) | ||
DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{Complex{T}}) where {Tag,T<:Real,V,N} = DualPlan{Complex{Dual{Tag,T,N}},typeof(p)}(p) | ||
dualplan(D, p) = DualPlan(D, p) | ||
Base.size(p::DualPlan) = Base.tail(size(p.p)) | ||
Base.:*(p::DualPlan{DT}, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) | ||
Base.:*(p::DualPlan{Complex{DT}}, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) | ||
|
||
function LinearAlgebra.mul!(y::AbstractArray{<:Dual}, p::DualPlan, x::AbstractArray{<:Dual}) | ||
LinearAlgebra.mul!(dual2array(y), p.p, dual2array(x)) # even though `Dual` are immutable, when in an `Array` they can be modified. | ||
y | ||
end | ||
|
||
function LinearAlgebra.mul!(y::AbstractArray{<:Complex{<:Dual}}, p::DualPlan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) | ||
copyto!(y, p*x) # Complex duals cannot be reinterpret in-place | ||
end | ||
|
||
|
||
for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) | ||
@eval begin | ||
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) | ||
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) | ||
end | ||
end | ||
|
||
|
||
for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? | ||
@eval begin | ||
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims)) | ||
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims)) | ||
end | ||
end | ||
|
||
|
||
end # module |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
using AbstractFFTs | ||
using ForwardDiff | ||
using Test | ||
using ForwardDiff: Dual, partials, value | ||
|
||
# Needed until https://github.com/JuliaDiff/ForwardDiff.jl/pull/732 is merged | ||
complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) | ||
|
||
@testset "ForwardDiff extension tests" begin | ||
x1 = Dual.(1:4.0, 2:5, 3:6) | ||
|
||
@test AbstractFFTs.complexfloat(x1)[1] === AbstractFFTs.complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im | ||
@test AbstractFFTs.realfloat(x1)[1] === AbstractFFTs.realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) | ||
|
||
@test fft(x1, 1)[1] isa Complex{<:Dual} | ||
|
||
@testset "$f" for f in (fft, ifft, rfft, bfft) | ||
@test value.(f(x1)) == f(value.(x1)) | ||
@test complexpartials.(f(x1), 1) == f(partials.(x1, 1)) | ||
@test complexpartials.(f(x1), 2) == f(partials.(x1, 2)) | ||
end | ||
|
||
@test ifft(fft(x1)) ≈ x1 | ||
@test irfft(rfft(x1), length(x1)) ≈ x1 | ||
@test brfft(rfft(x1), length(x1)) ≈ 4x1 | ||
|
||
f = x -> real(fft([x; 0; 0])[1]) | ||
@test ForwardDiff.derivative(f,0.1) ≈ 1 | ||
|
||
r = x -> real(rfft([x; 0; 0])[1]) | ||
@test ForwardDiff.derivative(r,0.1) ≈ 1 | ||
|
||
|
||
n = 100 | ||
θ = range(0,2π; length=n+1)[1:end-1] | ||
# emperical from Mathematical | ||
@test ForwardDiff.derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485 | ||
|
||
# c = x -> dct([x; 0; 0])[1] | ||
# @test derivative(c,0.1) ≈ 1 | ||
|
||
@testset "matrix" begin | ||
A = x1 * (1:10)' | ||
@test value.(fft(A)) == fft(value.(A)) | ||
@test complexpartials.(fft(A), 1) == fft(partials.(A, 1)) | ||
@test complexpartials.(fft(A), 2) == fft(partials.(A, 2)) | ||
|
||
@test value.(fft(A, 1)) == fft(value.(A), 1) | ||
@test complexpartials.(fft(A, 1), 1) == fft(partials.(A, 1), 1) | ||
@test complexpartials.(fft(A, 1), 2) == fft(partials.(A, 2), 1) | ||
|
||
@test value.(fft(A, 2)) == fft(value.(A), 2) | ||
@test complexpartials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) | ||
@test complexpartials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) | ||
end | ||
|
||
c1 = complex.(x1) | ||
@test mul!(similar(c1), plan_fft(x1), x1) == fft(x1) | ||
@test mul!(similar(c1), plan_fft(c1), c1) == fft(c1) | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.