Skip to content

Commit be091aa

Browse files
committed
Initial pass at switching CompositeBundle over to TaylorTangentBundle
1 parent 7c0641c commit be091aa

File tree

3 files changed

+94
-5
lines changed

3 files changed

+94
-5
lines changed

src/stage1/forward.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ function shuffle_down(b::CompositeBundle{N, B}) where {N, B}
5151
z
5252
end
5353

54+
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
55+
z₀ = primal(r)[1]
56+
z₁ = partial(r, 1)[1]
57+
z₂ = primal(r)[2]
58+
z₁₂ = partial(r, 1)[2]
59+
if z₁ == z₂
60+
return TaylorBundle{2}(z₀, (z₁, z₁₂))
61+
else
62+
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
63+
end
64+
end
65+
5466
function shuffle_up(r::CompositeBundle{1})
5567
z₀ = primal(r.tup[1])
5668
z₁ = partial(r.tup[1], 1)
@@ -76,6 +88,7 @@ isswifty(::UniformBundle) = true
7688
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
7789
isswifty(::Any) = false
7890

91+
#TODO: port this to TaylorTangent:
7992
function shuffle_up(r::CompositeBundle{N}) where {N}
8093
a, b = r.tup
8194
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)

src/stage1/recurse_fwd.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,47 @@ struct ∂vararg{N}; end
44

55
(::∂vararg{N})() where {N} = ZeroBundle{N}(())
66
function (::∂vararg{N})(a::AbstractTangentBundle{N}...) where N
7-
CompositeBundle{N, Tuple{map(x->basespace(typeof(x)), a)...}}(a)
7+
B = Tuple{map(x->basespace(Core.Typeof(x)), a)...}
8+
return (∂☆new{N}())(B, a...)
89
end
910

1011
struct ∂☆new{N}; end
1112

12-
(::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N} =
13-
CompositeBundle{N, B}(a)
13+
# we split out the 1st order derivative as a special case for performance
14+
# but the nth order case does also work for this
15+
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
16+
primal_args = map(primal, xs)
17+
the_primal = B <: Tuple ? B(primal_args) : B(primal_args...)
18+
19+
tangent_tup = map(x->partial(x, 1), xs)
20+
the_partial = if B<:Tuple
21+
Tangent{B, typeof(tangent_tup)}(tangent_tup)
22+
else
23+
names = fieldnames(B)
24+
tangent_nt = NamedTuple{names}(tangent_tup)
25+
Tangent{B, typeof(tangent_nt)}(tangent_nt)
26+
end
27+
return TaylorBundle{1, B}(the_primal, (the_partial,))
28+
end
29+
30+
function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
31+
primal_args = map(primal, xs)
32+
the_primal = B <: Tuple ? B(primal_args) : B(primal_args...)
33+
34+
the_partials = ntuple(Val{N}()) do ii
35+
iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking
36+
tangent_tup = map(x->partial(x, ii), xs)
37+
tangent = if B<:Tuple
38+
Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup)
39+
else
40+
names = fieldnames(B)
41+
tangent_nt = NamedTuple{names}(tangent_tup)
42+
Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt)
43+
end
44+
return tangent
45+
end
46+
return TaylorBundle{N, B}(the_primal, the_partials)
47+
end
1448

1549
@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))
1650

test/forward.jl

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd
2525
# Integration tests
2626
@test recursive_sin'(1.0) == cos(1.0)
2727
@test recursive_sin''(1.0) == -sin(1.0)
28-
# Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}
29-
# should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}.
28+
3029
@test_broken recursive_sin'''(1.0) == -cos(1.0)
3130
@test_broken recursive_sin''''(1.0) == sin(1.0)
3231
@test_broken recursive_sin'''''(1.0) == cos(1.0)
@@ -40,6 +39,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd
4039
end
4140

4241
# Some Basic Mixed Mode tests
42+
# TODO: unbreak this
4343
function sin_twice_fwd(x)
4444
let var"'" = Diffractor.PrimeDerivativeFwd
4545
sin''(x)
@@ -90,4 +90,46 @@ end
9090
end
9191
end
9292

93+
94+
@testset "structs" begin
95+
struct IDemo
96+
x::Float64
97+
y::Float64
98+
end
99+
100+
function foo(a)
101+
obj = IDemo(2.0, a)
102+
return obj.x * obj.y
103+
end
104+
105+
let var"'" = Diffractor.PrimeDerivativeFwd
106+
@test foo'(100.0) == 2.0
107+
@test foo''(100.0) == 0.0
108+
end
109+
end
110+
111+
@testset "tuples" begin
112+
function foo(a)
113+
tup = (2.0, a)
114+
return first(tup) * tup[2]
115+
end
116+
117+
let var"'" = Diffractor.PrimeDerivativeFwd
118+
@test foo'(100.0) == 2.0
119+
@test foo''(100.0) == 0.0
120+
end
121+
end
122+
123+
@testset "vararg" begin
124+
function foo(a)
125+
tup = (2.0, a)
126+
return *(tup...)
127+
end
128+
129+
let var"'" = Diffractor.PrimeDerivativeFwd
130+
@test foo'(100.0) == 2.0
131+
@test foo''(100.0) == 0.0
132+
end
133+
end
134+
93135
end

0 commit comments

Comments
 (0)