Skip to content

Fuse frule #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
name = "ForwardDiff2"
uuid = "994df76e-a4c1-5e1f-bd5c-23b9b5303d4f"
authors = ["Yingbo Ma <mayingbo5@gmail.com>"]
authors = ["Yingbo Ma <mayingbo5@gmail.com>", "Shashi Gowda <gowda@mit.edu>"]
version = "0.1.0"

[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
julia = "1"
Calculus = "0.5.1"
Cassette = "0.3.0"
ChainRules = "0.2.5"
ChainRulesCore = "0.4"
ChainRules = "0.3.1"
ChainRulesCore = "0.5.3"
DiffRules = "1.0.0"
MacroTools = "0.5.3"
NaNMath = "0.3.3"
SafeTestsets = "0.0.1"
SpecialFunctions = "0.9"
StaticArrays = "0.11, 0.12"

[extras]
Expand Down
54 changes: 49 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,60 @@
[![Build Status](https://travis-ci.org/YingboMa/ForwardDiff2.jl.svg?branch=master)](https://travis-ci.org/YingboMa/ForwardDiff2.jl)
[![codecov](https://codecov.io/gh/YingboMa/ForwardDiff2.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/YingboMa/ForwardDiff2.jl)

`ForwardDiff2` = `ForwardDiff.jl` + `ChainRules.jl` + Struct of arrays + `DualCache`
`ForwardDiff2` = `ForwardDiff.jl` + `ChainRules.jl` + Struct of arrays

### Warning!!!: This package is still work-in-progress

User API:
```julia
julia> using ForwardDiff2: D

julia> v = rand(2)
2-element Array{Float64,1}:
0.22260830987887537
0.6397089507287486

julia> D(prod)(v) # gradient
1×2 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}:
0.639709 0.222608

julia> D(cumsum)(v) # Jacobian
2×2 Array{Float64,2}:
1.0 0.0
1.0 1.0

julia> D(D(prod))(v) # Hessian
2×2 LinearAlgebra.Adjoint{Float64,Array{Float64,2}}:
0.0 1.0
1.0 0.0
```

Note that `ForwardDiff2.jl` also works with `ModelingToolkit.jl`:
```julia
julia> using ModelingToolkit

julia> @variables v[1:2]
(Operation[v₁, v₂],)

julia> D(prod)(v) # gradient
1×2 LinearAlgebra.Adjoint{Operation,Array{Operation,1}}:
conj(1v₂ + v₁ * identity(0)) conj(identity(0) * v₂ + v₁ * 1)

julia> D(cumsum)(v) # Jacobian
2×2 Array{Expression,2}:
Constant(1) identity(0)
identity(0) + 1 1 + identity(0)

julia> D(D(prod))(v) # Hessian
2×2 LinearAlgebra.Adjoint{Operation,Array{Operation,2}}:
conj((1 * identity(0) + v₁ * 0) + (1 * identity(0) + v₂ * 0)) conj((identity(0) * identity(0) + v₁ * 0) + (1 * 1 + v₂ * 0))
conj((1 * 1 + v₁ * 0) + (identity(0) * identity(0) + v₂ * 0)) conj((identity(0) * 1 + v₁ * 0) + (identity(0) * 1 + v₂ * 0))
```

Planned features:

- works both on GPU and CPU
- scalar forward mode AD
- vectorized forward mode AD
- [Dual cache](http://docs.juliadiffeq.org/latest/basics/faq.html#I-get-Dual-number-errors-when-I-solve-my-ODE-with-Rosenbrock-or-SDIRK-methods...?-1)
- nested differentiation
- hyper duals (?)
- user-extensible scalar and tensor derivative definitions
- in-place function
- sparsity exploitation (color vector support)
Expand Down
112 changes: 76 additions & 36 deletions src/dual_context.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,63 @@
using Cassette
using ChainRules
using ChainRulesCore
import ChainRulesCore: Wirtinger, Zero
import ChainRulesCore: Zero

# TODO: remove the copy pasted code and add that package
# copyed from SpecializeVarargs.jl, written by @MasonProtter
using MacroTools: MacroTools, splitdef, combinedef, @capture

macro specialize_vararg(n::Int, fdef::Expr)
@assert n > 0

macros = Symbol[]
while fdef.head == :macrocall && length(fdef.args) == 3
push!(macros, fdef.args[1])
fdef = fdef.args[3]
end

d = splitdef(fdef)
args = d[:args][end]
@assert d[:args][end] isa Expr && d[:args][end].head == Symbol("...") && d[:args][end].args[] isa Symbol
args_symbol = d[:args][end].args[]

fdefs = Expr(:block)

for i in 1:n-1
di = deepcopy(d)
pop!(di[:args])
args = Tuple(gensym("arg$j") for j in 1:i)
Ts = Tuple(gensym("T$j") for j in 1:i)

args_with_Ts = ((arg, T) -> :($arg :: $T)).(args, Ts)

di[:whereparams] = (di[:whereparams]..., Ts...)

push!(di[:args], args_with_Ts...)
pushfirst!(di[:body].args, :($args_symbol = $(Expr(:tuple, args...))))
cfdef = combinedef(di)
mcfdef = isempty(macros) ? cfdef : foldr((m,f) -> Expr(:macrocall, m, nothing, f), macros, init=cfdef)
push!(fdefs.args, mcfdef)
end

di = deepcopy(d)
pop!(di[:args])
args = tuple((gensym() for j in 1:n)..., :($(gensym("args"))...))
Ts = Tuple(gensym("T$j") for j in 1:n)

args_with_Ts = (((arg, T) -> :($arg :: $T)).(args[1:end-1], Ts)..., args[end])

di[:whereparams] = (di[:whereparams]..., Ts...)

push!(di[:args], args_with_Ts...)
pushfirst!(di[:body].args, :($args_symbol = $(Expr(:tuple, args...))))

cfdef = combinedef(di)
mcfdef = isempty(macros) ? cfdef : foldr((m,f) -> Expr(:macrocall, m, nothing, f), macros, init=cfdef)
push!(fdefs.args, mcfdef)

esc(fdefs)
end

using Cassette: overdub, Context, nametype, similarcontext

Expand Down Expand Up @@ -30,8 +86,6 @@ end
@inline _partials(::Any, x) = Zero()
@inline _partials(::Tag{T}, d::Dual{Tag{T}}) where T = d.partials

Wirtinger(primal, conjugate) = Wirtinger.(primal, conjugate)

@inline _values(S, xs) = map(x->_value(S, x), xs)
@inline _partialss(S, xs) = map(x->_partials(S, x), xs)

Expand All @@ -48,64 +102,54 @@ Wirtinger(primal, conjugate) = Wirtinger.(primal, conjugate)
end

# actually interesting:

@inline isinteresting(ctx::TaggedCtx, f, a) = anydual(a)
@inline isinteresting(ctx::TaggedCtx, f, a, b) = anydual(a, b)
@inline isinteresting(ctx::TaggedCtx, f, a, b, c) = anydual(a, b, c)
@inline isinteresting(ctx::TaggedCtx, f, a, b, c, d) = anydual(a, b, c, d)
@inline isinteresting(ctx::TaggedCtx, f, args...) = false
@inline isinteresting(ctx::TaggedCtx, f::typeof(Base.show), args...) = false
@inline isinteresting(ctx::TaggedCtx, f, args...) = anydual(args...)
@inline isinteresting(ctx::TaggedCtx, f::Core.Builtin, args...) = false
@inline isinteresting(ctx::TaggedCtx, f::Union{typeof(ForwardDiff2.find_dual),
typeof(ForwardDiff2.anydual)}, args...) = false

@inline function _frule_overdub2(ctx::TaggedCtx{T}, f, args...) where T
@specialize_vararg 4 @inline function _frule_overdub2(ctx::TaggedCtx{T}, f::F, args...) where {T,F}
# Here we can assume that one or more `args` is a Dual with tag
# of type T.

tag = Tag{T}()
# unwrap only duals with the tag T.
vs = _values(tag, args)

# extract the partials only for the current tag
# so we can pass them to the pushforward
ps = _partialss(tag, args)

# default `dself` to `Zero()`
dself = Zero()

# call frule to see if there is a rule for this call:
if ctx.metadata isa Tag
ctx1 = similarcontext(ctx, metadata=oldertag(ctx.metadata))

# we call frule with an older context because the Dual numbers may
# themselves contain Dual numbers that were created in an older context
frule_result = overdub(ctx1, frule, f, vs...)
frule_result = overdub(ctx1, frule, f, vs..., dself, ps...)
else
frule_result = frule(f, vs...)
frule_result = frule(f, vs..., dself, ps...)
end

if frule_result === nothing
# this means there is no frule
# We can't just do f(args...) here because `f` might be
# a closure which closes over a Dual number, hence we call
# recurse. Recurse overdubs the calls inside `f` and not `f` itself

return Cassette.overdub(ctx, f, args...)
else
# this means there exists an frule for this specific call.
# frule_result is then a tuple (val, pushforward) where val
# is the primal result. (Note: this may be Dual numbers but only
# with an older tag)
val, pushforward = frule_result

# extract the partials only for the current tag
# so we can pass them to the pushforward
ps = _partialss(tag, args)

# Call the pushforward to get new partials
# we call it with the older context because the partials
# might themselves be Duals from older contexts
if ctx.metadata isa Tag
ctx1 = similarcontext(ctx, metadata=oldertag(ctx.metadata))
∂s = overdub(ctx1, pushforward, Zero(), ps...)
else
∂s = pushforward(Zero(), ps...)
end
val, ∂s = frule_result

# Attach the new partials to the primal result
# multi-output `f` such as result in the new partials being
# a tuple, we handle both cases:
return if ∂s isa Tuple
map(val, ∂s) do v, ∂
Dual{Tag{T}}(v, ∂)
Expand All @@ -116,7 +160,7 @@ end
end
end

@inline function alternative(ctx::TaggedCtx{T}, f, args...) where {T}
@specialize_vararg 4 @inline function alternative(ctx::TaggedCtx{T}, f::F, args...) where {T,F}
# This method only executes if `args` contains at least 1 Dual
# the question is what is its tag

Expand Down Expand Up @@ -161,10 +205,6 @@ end


##### Inference Hacks
# this makes `log` work by making throw_complex_domainerror inferable, but not really sure why
@inline isinteresting(ctx::TaggedCtx, f::typeof(Core.throw), xs) = true
# add `DualContext` here to avoid ambiguity
@noinline alternative(ctx::Union{DualContext,TaggedCtx}, f::typeof(Core.throw), arg) = throw(arg)

@inline isinteresting(ctx::TaggedCtx, f::typeof(Base.print_to_string), args...) = true
@noinline alternative(ctx::Union{DualContext,TaggedCtx}, f::typeof(Base.print_to_string), args...) = f(args...)
@inline isinteresting(ctx::TaggedCtx, f::Union{typeof(Base.print_to_string),typeof(hash)}, args...) = false
@inline Cassette.overdub(ctx::TaggedCtx, f::Union{typeof(Base.print_to_string),typeof(hash)}, args...) = f(args...)
@inline Cassette.overdub(ctx::TaggedCtx, f::Core.Builtin, args...) = f(args...)
9 changes: 3 additions & 6 deletions src/dualarray.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
using StaticArrays: SVector
using StaticArrays: SVector, StaticArray

partial_type(::Dual{T,V,P}) where {T,V,P} = P

struct DualArray{T,E,M,V<:AbstractArray,D<:AbstractArray} <: AbstractArray{E,M}
data::V
partials::D
function DualArray{T}(v::AbstractArray{E,N}, p::P) where {T,E,N,P<:AbstractArray}
# TODO: non-allocating X?
X = typeof(similar(p, Base.tail(ntuple(_->0, Val(ndims(P))))))
X = p isa StaticArray ? typeof(vec(p)[axes(p, ndims(p))]) : typeof(vec(p))
# we need the eltype of `DualArray` to be `Dual{T,E,X}` as opposed to
# some kind of `view`, because we can convert `SubArray` to `Array` but
# not vise a versa.
Expand Down Expand Up @@ -37,7 +36,7 @@ function Base.print_array(io::IO, da::DualArray)
end

DualArray(a::AbstractArray, b::AbstractArray) = DualArray{typeof(dualtag())}(a, b)
npartials(d::DualArray) = size(d.partials, ndims(d.partials))
npartials(d::DualArray) = (ps = allpartials(d); size(ps, ndims(ps)))
data(d::DualArray) = d.data
allpartials(d::DualArray) = d.partials

Expand All @@ -51,8 +50,6 @@ Base.IndexStyle(d::DualArray) = Base.IndexStyle(data(d))
Base.similar(d::DualArray{T}, ::Type{S}, dims::Dims) where {T, S} = DualArray{T}(similar(data(d)), similar(allpartials(d)))
Base.eachindex(d::DualArray) = eachindex(data(d))

using StaticArrays

Base.@propagate_inbounds _slice(A, i...) = @view A[i..., :]
Base.@propagate_inbounds _slice(A::StaticArray, i...) = A[i..., :]

Expand Down
12 changes: 7 additions & 5 deletions src/dualnumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ dualtag() = nothing

@inline partials(d::Dual) = d.partials

@inline npartials(d::Dual) = (ps = d.partials) isa Wirtinger ? 1 : length(ps)
@inline npartials(d::Dual) = (ps = partials(d)) isa ChainRulesCore.AbstractDifferential ? 1 : length(d.partials)

#####################
# Generic Functions #
Expand Down Expand Up @@ -128,11 +128,13 @@ function Base.write(io::IO, d::Dual)
write(io, partials(d))
end

@inline Base.zero(d::Dual) = zero(typeof(d))
@inline Base.zero(::Type{Dual{T,V,P}}) where {T,V,P} = Dual{T}(zero(V), zero(P))
@inline Base.zero(d::Dual{T}) where T = Dual{T}(zero(value(d)), zero(partials(d)))
#@inline Base.zero(d::Dual) = zero(typeof(d))
#@inline Base.zero(::Type{Dual{T,V,P}}) where {T,V,P} = Dual{T}(zero(V), zero(P))

@inline Base.one(d::Dual) = one(typeof(d))
@inline Base.one(::Type{Dual{T,V,P}}) where {T,V,P} = Dual{T}(one(V), zero(P))
@inline Base.one(d::Dual{T}) where T = Dual{T}(one(value(d)), zero(partials(d)))
#@inline Base.one(d::Dual) = one(typeof(d))
#@inline Base.one(::Type{Dual{T,V,P}}) where {T,V,P} = Dual{T}(one(V), zero(P))

@inline Random.rand(rng::AbstractRNG, d::Dual) = rand(rng, value(d))
@inline Random.rand(::Type{Dual{T,V,P}}) where {T,V,P} = Dual{T}(rand(V), zero(P))
Expand Down
4 changes: 4 additions & 0 deletions test/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ using StaticArrays
# Hessian
@test D(D(x->x[1]^x[2] + x[3]^3 + x[3]*x[2]*x[1]))(@SVector[1,2,3]) === @SMatrix [2 4 2; 4 0 1; 2 1 18.]
@test D(D(x->x[1]^x[2] + x[3]^3 + x[3]*x[2]*x[1]))([1,2,3]) == [2 4 2; 4 0 1; 2 1 18.]
# inference
@inferred D(x->exp(x) + x^x + cos(x) + tan(x) + 2^x)(1)
# broken due to `Core._apply`
@test_broken @inferred D(x->exp(x) + x^x + cos(x) + tan(x) + 2^x + log(cos(x)) + sec(pi*x) - angle(x) + one(x) / log1p(sin(x)))(1)
end
14 changes: 7 additions & 7 deletions test/dualtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ _div_partials(a, b, aval, bval) = _mul_partials(a, b, inv(bval), -(aval / (bval*

const Partials{N,V} = SVector{N,V}

for N in (0,3), M in (0,4), V in (Int, Float32)
for N in (3), M in (4), V in (Int, Float32)
println(" ...testing Dual{..,$V,$N} and Dual{..,Dual{..,$V,$M},$N}")


Expand Down Expand Up @@ -334,13 +334,13 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
# Multiplication #
#----------------#

@test @drun1(FDNUM * FDNUM2) === Dual{Tag1}(value(FDNUM) * value(FDNUM2), _mul_partials(partials(FDNUM), partials(FDNUM2), value(FDNUM2), value(FDNUM)))
@test dual_isapprox(@drun1(FDNUM * FDNUM2), Dual{Tag1}(value(FDNUM) * value(FDNUM2), _mul_partials(partials(FDNUM), partials(FDNUM2), value(FDNUM2), value(FDNUM))))
@test @drun1(FDNUM * PRIMAL) === Dual{Tag1}(value(FDNUM) * PRIMAL, partials(FDNUM) * PRIMAL)
@test @drun1(PRIMAL * FDNUM) === Dual{Tag1}(value(FDNUM) * PRIMAL, partials(FDNUM) * PRIMAL)

@test @drun2(NESTED_FDNUM * NESTED_FDNUM2) === @drun1 Dual{Tag2}(value(NESTED_FDNUM) * value(NESTED_FDNUM2), _mul_partials(partials(NESTED_FDNUM), partials(NESTED_FDNUM2), value(NESTED_FDNUM2), value(NESTED_FDNUM)))
@test @drun2(NESTED_FDNUM * PRIMAL) === @drun1 Dual{Tag2}(value(NESTED_FDNUM) * PRIMAL, partials(NESTED_FDNUM) * PRIMAL)
@test @drun2(PRIMAL * NESTED_FDNUM) === @drun1 Dual{Tag2}(value(NESTED_FDNUM) * PRIMAL, partials(NESTED_FDNUM) * PRIMAL)
@test_broken @drun2(NESTED_FDNUM * PRIMAL) === @drun1 Dual{Tag2}(value(NESTED_FDNUM) * PRIMAL, partials(NESTED_FDNUM) * PRIMAL)
@test_broken @drun2(PRIMAL * NESTED_FDNUM) === @drun1 Dual{Tag2}(value(NESTED_FDNUM) * PRIMAL, partials(NESTED_FDNUM) * PRIMAL)

# Division #
#----------#
Expand All @@ -362,7 +362,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
@test dual_isapprox(@drun1(PRIMAL / FDNUM), dual1(PRIMAL / value(FDNUM), (-(PRIMAL) / value(FDNUM)^2) * partials(FDNUM)))

@test dual_isapprox(@drun2(NESTED_FDNUM / NESTED_FDNUM2), @drun1 Dual{Tag2}(value(NESTED_FDNUM) / value(NESTED_FDNUM2), _div_partials(partials(NESTED_FDNUM), partials(NESTED_FDNUM2), value(NESTED_FDNUM), value(NESTED_FDNUM2))))
@test dual_isapprox(@drun2(NESTED_FDNUM / PRIMAL), @drun1 Dual{Tag2}(value(NESTED_FDNUM) / PRIMAL, partials(NESTED_FDNUM) / PRIMAL))
@test_broken dual_isapprox(@drun2(NESTED_FDNUM / PRIMAL), @drun1 Dual{Tag2}(value(NESTED_FDNUM) / PRIMAL, partials(NESTED_FDNUM) / PRIMAL))
@test dual_isapprox(@drun2(PRIMAL / NESTED_FDNUM), @drun1 Dual{Tag2}(PRIMAL / value(NESTED_FDNUM), (-(PRIMAL) / value(NESTED_FDNUM)^2) * partials(NESTED_FDNUM)))

# Exponentiation #
Expand Down Expand Up @@ -399,15 +399,15 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
if V != Int
for (M, f, arity) in DiffRules.diffrules()
in(f, (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)) && continue
#println(" ...auto-testing $(M).$(f) with $arity arguments")
println(" ...auto-testing $(M).$(f) with $arity arguments")
if arity == 1
deriv = DiffRules.diffrule(M, f, :x)
modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? one(V) : zero(V)
@eval begin
x = rand() + $modifier
dx = dualrun(()->$M.$f(Dual(x, one(x))))
@dtest value(dx) == $M.$f(x)
@dtest partials(dx)[1] == $deriv
@dtest partials(dx)[1] $deriv
end
elseif arity == 2
derivs = DiffRules.diffrule(M, f, :x, :y)
Expand Down