Skip to content

Commit 4d823df

Browse files
authored
Merge pull request #13 from YingboMa/s/api
User API & DualArray
2 parents 2fd8a91 + 3d414ec commit 4d823df

File tree

10 files changed

+113
-152
lines changed

10 files changed

+113
-152
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.1.0"
77
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
88
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1213

src/ForwardDiff2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ include("custom_dispatch.jl")
55
include("tag.jl")
66
include("dualarray.jl")
77
include("dual_context.jl")
8-
include("jacobian.jl")
8+
include("api.jl")
99

1010
# Experimental
1111
#include("aosoa.jl")

src/api.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using StaticArrays: StaticArray, SMatrix, SVector
2+
using LinearAlgebra: Diagonal, I
3+
4+
extract_diffresult(xs::AbstractArray{<:Number}) = xs
5+
# need to optimize
6+
extract_diffresult(xs) = hcat(xs...)'
7+
function extract_diffresult(xs::StaticVector{<:StaticArray})
8+
tup = reduce((x,y)->tuple(x..., y...), map(x->x.data, xs.data))
9+
SMatrix{length(xs), length(xs[1])}(tup)
10+
end
11+
extract_diffresult(xs::AbstractMatrix{<:Number}) = xs
12+
extract_diffresult(xs::AbstractVector{<:Number}) = xs'
13+
14+
allpartials(xs) = map(partials, xs)
15+
16+
function seed(v::SVector{N}) where N
17+
SMatrix{N,N,eltype(v)}(I)
18+
end
19+
20+
function seed(v)
21+
Matrix(Diagonal(map(one, v)))
22+
end
23+
24+
function D(f)
25+
# grad
26+
function deriv(arg::AbstractArray)
27+
# always chunk
28+
darr = dualrun(()->DualArray(arg, seed(arg)))
29+
res = dualrun(()->f(darr))
30+
diffres = extract_diffresult(allpartials(res))
31+
return diffres
32+
end
33+
# scalar
34+
function deriv(arg)
35+
dualrun() do
36+
dualized = map(x->Dual(x, one(x)), arg)
37+
res = f(dualized)
38+
return map(partials, res)
39+
end
40+
end
41+
return deriv
42+
end
43+
44+
#=
45+
# scalar case: f: R -> something
46+
D(sin)(1.0)
47+
D(x->[x, x^2])(3)
48+
49+
# gradient case: f: R^n -> R
50+
D(sum)([1,2,3])
51+
52+
# Jacobian case: f: R^n -> R^m
53+
=#

src/dual_context.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ Wirtinger(primal, conjugate) = Wirtinger.(primal, conjugate)
4747
Core._apply(alternative, (ctx, g), args...)
4848
end
4949

50-
# this makes `log` work by making throw_complex_domainerror inferable, but not really sure why
51-
@inline isinteresting(ctx::TaggedCtx, f::typeof(Core.throw), xs) = true
52-
# add `DualContext` here to avoid ambiguity
53-
@noinline alternative(ctx::Union{DualContext,TaggedCtx}, f::typeof(Core.throw), arg) = throw(arg)
54-
5550
# actually interesting:
5651

5752
@inline isinteresting(ctx::TaggedCtx, f, a) = anydual(a)
@@ -163,3 +158,13 @@ for pred in BINARY_PREDICATES
163158
$pred(vx, vy)
164159
end
165160
end
161+
162+
163+
##### Inference Hacks
164+
# this makes `log` work by making throw_complex_domainerror inferable, but not really sure why
165+
@inline isinteresting(ctx::TaggedCtx, f::typeof(Core.throw), xs) = true
166+
# add `DualContext` here to avoid ambiguity
167+
@noinline alternative(ctx::Union{DualContext,TaggedCtx}, f::typeof(Core.throw), arg) = throw(arg)
168+
169+
@inline isinteresting(ctx::TaggedCtx, f::typeof(Base.print_to_string), args...) = true
170+
@noinline alternative(ctx::Union{DualContext,TaggedCtx}, f::typeof(Base.print_to_string), args...) = f(args...)

src/dualarray.jl

Lines changed: 34 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
using StaticArrays: SVector
2+
3+
partial_type(::Dual{T,V,P}) where {T,V,P} = P
24
# TODO: Tagging?
35
# TODO: Integrate better with SVector. Maybe even use SIMD.jl?
46

5-
struct DualArray{T,E,M,D<:AbstractArray,I} <: AbstractArray{E,M}
6-
data::D
7-
DualArray{T,I}(a::AbstractArray{E,N}) where {T,I,E,N} = new{T,E,N-1,typeof(a),I}(a)
7+
struct DualArray{T,E,M,V<:AbstractArray,D<:AbstractArray} <: AbstractArray{E,M}
8+
data::V
9+
partials::D
10+
function DualArray{T}(v::AbstractArray{E,N}, p::P) where {T,E,N,P<:AbstractArray}
11+
# TODO: non-allocating X?
12+
X = typeof(similar(p, Base.tail(ntuple(_->0, Val(ndims(P))))))
13+
# we need the eltype of `DualArray` to be `Dual{T,E,X}` as opposed to
14+
# some kind of `view`, because we can convert `SubArray` to `Array` but
15+
# not vise a versa.
16+
#
17+
# We need that to differentiate through the following code
18+
# `(foo(x::AbstractArray{T})::T) where {T} = x[1]`
19+
return new{T,Dual{T,E,X},N,typeof(v),typeof(p)}(v, p)
20+
end
821
end
922

1023
###
@@ -18,91 +31,39 @@ function Base.print_array(io::IO, da::DualArray)
1831
Base.printstyled(io, "Primals:\n", bold=false, color=2)
1932
prev_params = io isa IOContext ? io.dict : ()
2033
ioc = IOContext(io, prev_params..., sz)
21-
Base.print_array(ioc, value(da))
34+
Base.print_array(ioc, data(da))
2235
Base.println(io)
23-
for i=1:npartials(da)
24-
Base.printstyled(io,"Partials($i):\n", bold=false, color=3)
25-
Base.print_array(ioc, getindex.(partials.(da), i))
26-
i !== npartials(da) && Base.println(io)
27-
end
36+
Base.printstyled(io,"Partials:\n", bold=false, color=3)
37+
Base.print_array(ioc, allpartials(da))
2838
return nothing
2939
end
3040

31-
DualArray(a::AbstractArray) = DualArray{Nothing,size(a, ndims(a))-1}(a)
32-
npartials(d::DualArray{T,E,M,D,I}) where {T,E,M,D,I} = I
33-
tagtype(::Type{<:DualArray{T}}) where {T} = T
34-
tagtype(::T) where {T<:DualArray} = tagtype(T)
41+
DualArray(a::AbstractArray, b::AbstractArray) = DualArray{typeof(dualtag())}(a, b)
42+
npartials(d::DualArray) = size(d.partials, ndims(d.partials))
3543
data(d::DualArray) = d.data
44+
allpartials(d::DualArray) = d.partials
3645

3746
###
3847
### Array interface
3948
###
4049

41-
Base.eltype(d::DualArray{T,E}) where {T,E} = Dual{T,E,npartials(d)}
42-
droplast(d::Tuple) = d |> reverse |> Base.tail |> reverse
43-
Base.size(d::DualArray) = size(data(d)) |> droplast
44-
Base.size(d::DualArray, i) = i <= ndims(d) ? size(data(d), i) : 1
50+
#droplast(d::Tuple) = d |> reverse |> Base.tail |> reverse
51+
Base.size(d::DualArray) = size(data(d))
4552
Base.IndexStyle(d::DualArray) = Base.IndexStyle(data(d))
46-
Base.strides(d::DualArray) = strides(data(d)) |> droplast
47-
Base.similar(d::DualArray{T}, ::Type{S}, dims::Dims) where {T,S} = DualArray{T,npartials(d)}(similar(data(d), S, (dims..., npartials(d)+1)))
53+
Base.similar(d::DualArray{T}, ::Type{S}, dims::Dims) where {T, S} = DualArray{T}(similar(data(d)), similar(allpartials(d)))
54+
Base.eachindex(d::DualArray) = eachindex(data(d))
4855

49-
###
50-
### Broadcast interface
51-
###
56+
using StaticArrays
5257

53-
using Base.Broadcast: Broadcasted, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle
58+
Base.@propagate_inbounds _slice(A, i...) = @view A[i..., :]
59+
Base.@propagate_inbounds _slice(A::StaticArray, i...) = A[i..., :]
5460

55-
struct DualStyle{M,T,I,D} <: AbstractArrayStyle{M}
61+
Base.@propagate_inbounds function Base.getindex(d::DualArray{T}, i::Int...) where {T}
62+
return Dual{T}(data(d)[i...], _slice(allpartials(d), i...))
5663
end
5764

58-
Base.Broadcast.BroadcastStyle(::Type{<:DualArray{T,E,M,D,I}}) where {T,E,M,D,I} = DualStyle{M,T,I,typeof(Base.Broadcast.BroadcastStyle(D))}()
59-
function Base.similar(bc::Broadcasted{<:DualStyle{M,T,I,D}}, ::Type{E}) where {M,T,I,D,E}
60-
if E <: Dual
61-
V = valtype(E)
62-
arr = DualArray{T,I}(similar(Array{V}, (axes(bc)..., Base.OneTo(I+1)))) # TODO: work with arbitrary array types. Maybe use `ArrayInterface.jl`?
63-
else
64-
bc′ = convert(Broadcasted{D}, bc)
65-
arr = Base.similar(bc′, E)
66-
end
67-
return arr
68-
end
69-
Base.BroadcastStyle(::DualStyle{M,T,I,D}, ::DualStyle{M,T,I,V}) where {M,T,I,D,V} = DualStyle{M,T,I,typeof(Base.BroadcastStyle(D(), V()))}()
70-
Base.BroadcastStyle(::DualStyle{M,T,I,D}, B::BroadcastStyle) where {M,T,I,D} = DualStyle{M,T,I,typeof(Base.BroadcastStyle(D(), B))}()
71-
Base.BroadcastStyle(::DualStyle{M,T,I,D}, B::DefaultArrayStyle) where {M,T,I,D} = DualStyle{M,T,I,typeof(Base.BroadcastStyle(D(), B))}()
72-
73-
function value(d::DualArray)
74-
n = ndims(d)
75-
dd = data(d)
76-
return @view dd[ntuple(_ -> Colon(), Val(n))..., 1]
77-
end
78-
79-
function partials(d::DualArray)
80-
n = ndims(d)
81-
dd = data(d)
82-
return @view dd[ntuple(_ -> Colon(), Val(n))..., 2:end]
83-
end
84-
85-
Base.eachindex(d::DualArray) = eachindex(@view data(d)[:, 1])
86-
87-
Base.@propagate_inbounds function Base.getindex(d::DualArray, i::Int...)
88-
dd = data(d)
89-
# TODO: do something different if dd is not Linear index style
90-
ii = LinearIndices(size(d))[i...]
91-
val = dd[ii]
92-
slice_len = length(d)
93-
parts = ntuple(j->dd[j * slice_len + ii], Val(npartials(d)))
94-
return Dual(val, SVector(parts))
95-
end
96-
97-
Base.@propagate_inbounds function Base.setindex!(d::DualArray, dual, i::Int...)
98-
dd = data(d)
99-
ii = LinearIndices(size(d))[i...]
100-
dd[ii] = value(dual)
101-
102-
slice_len = length(d)
103-
ps = partials(dual)
104-
for j = 1:npartials(d)
105-
dd[j * slice_len + ii] = ps[j]
106-
end
65+
Base.@propagate_inbounds function Base.setindex!(d::DualArray{T}, dual::Dual{T}, i::Int...) where {T}
66+
data(d)[i...] = value(dual)
67+
allpartials(d)[i..., :] .= partials(dual)
10768
return dual
10869
end

src/jacobian.jl

Lines changed: 0 additions & 28 deletions
This file was deleted.

test/api.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Test
2+
using ForwardDiff2: D
3+
using StaticArrays
4+
5+
@testset begin
6+
@test D(sin)(1.0) == cos(1.0)
7+
@test D(x->[x, x^2])(3) == [1, 6]
8+
@test D(sum)([1,2,3]) == ones(3)'
9+
@test D(x->@SVector([x[1]^x[2], x[3]^3, x[3]*x[2]*x[1]]))(@SVector[1,2,3.]) === @SMatrix [2.0 0 0; 0 0 27; 6 3 2]
10+
@test D(cumsum)(@SVector([1,2,3])) == @SMatrix [1 0 0; 1 1 0; 1 1 1]
11+
@test D(cumsum)([1,2,3]) == [1 0 0; 1 1 0; 1 1 1]
12+
@test D(x->@SVector([x[1], x[2]]))(@SVector([1,2,3])) === @SMatrix [1 0 0; 0 1 0]
13+
end

test/dualarray.jl

Lines changed: 0 additions & 35 deletions
This file was deleted.

test/jacobian.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using SafeTestsets
2+
@time @safetestset "API Tests" begin include("api.jl") end
23
@time @safetestset "Nested Differentiation Tests" begin include("nested.jl") end
34
@time @safetestset "Dual Tests" begin include("dualtest.jl") end
4-
@time @safetestset "DualArray Tests" begin include("dualarray.jl") end
5-
@time @safetestset "Jacobian Tests" begin include("jacobian.jl") end

0 commit comments

Comments
 (0)