Skip to content

Commit bc4773a

Browse files
authored
WIP: Refactor forward mode data structures (#54)
To allow chunking like ForwardDiff. import `∂⃖¹` some fixes
1 parent d7648f1 commit bc4773a

File tree

6 files changed

+151
-121
lines changed

6 files changed

+151
-121
lines changed

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued
6464
For `x` in a one dimensional manifold, map x to the trivial, unital, 1st order
6565
tangent bundle. It should hold that `∀x ⟨∂x(x), dx(x)⟩ = 1`
6666
"""
67-
∂x(x::Real) = TangentBundle{1}(x, (one(x),))
67+
∂x(x::Real) = ExplicitTangentBundle{1}(x, (one(x),))
6868
∂x(x) = error("Tangent space not defined for `$(typeof(x)).")
6969

7070
struct ∂xⁿ{N}; end
@@ -177,7 +177,7 @@ raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{plus1(N),T
177177

178178
function (f::PrimeDerivativeFwd{1})(x)
179179
z = ∂☆¹(ZeroBundle{1}(getfield(f, :f)), ∂x(x))
180-
z.partials[1]
180+
z.tangent.partials[1]
181181
end
182182

183183
function (f::PrimeDerivativeFwd{N})(x) where N

src/jet.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(map), f, a::Array) where {N}
187187
∂f = ∂☆{N}()(ZeroBundle{N}(f),
188188
TaylorBundle{N}(x,
189189
(one(x), (zero(x) for i = 1:(N-1))...,)))
190-
@assert isa(∂f, TaylorBundle) || isa(∂f, TangentBundle{1})
190+
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
191191
Jet{typeof(x), N}(x, ∂f.primal,
192-
isa(∂f, TangentBundle) ? ∂f.partials : ∂f.coeffs)
192+
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
193193
end
194194
∂⃖ₙ(mapev, js, a)
195195
end
@@ -243,18 +243,18 @@ end
243243
O = min(M,N)
244244
quote
245245
domain_check(j, x.primal)
246-
coeffs = x.coeffs
246+
coeffs = x.tangent.coeffs
247247
TaylorBundle{$O}(j[0],
248248
($((:(jet_taylor_ev(Val{$i}(), coeffs, j)) for i = 1:O)...),))
249249
end
250250
end
251251

252-
function (j::Jet{T, 1} where T)(x::TangentBundle{1})
252+
function (j::Jet{T, 1} where T)(x::ExplicitTangentBundle{1})
253253
domain_check(j, x.primal)
254-
coeffs = x.partials
255-
TangentBundle{1}(j[0], (jet_taylor_ev(Val{1}(), coeffs, j),))
254+
coeffs = x.tangent.partials
255+
ExplicitTangentBundle{1}(j[0], (jet_taylor_ev(Val{1}(), coeffs, j),))
256256
end
257257

258-
function (j::Jet{T, N} where T)(x::TangentBundle{N, M}) where {N, M}
258+
function (j::Jet{T, N} where T)(x::ExplicitTangentBundle{N, M}) where {N, M}
259259
error("TODO")
260260
end

src/stage1/forward.jl

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
partial(x::TangentBundle, i) = x.partials[i]
2-
partial(x::TaylorBundle{1}, i) = x.coeffs[i]
3-
partial(x::UniformBundle, i) = x.partial
4-
partial(x::CompositeBundle{N, B}, i) where {N, B} = Tangent{B}(map(x->partial(x, i), x.tup)...)
5-
partial(x::ZeroTangent, i) = ZeroTangent()
1+
partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i)
2+
partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i)
3+
partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
4+
partial(x::UniformTangent, i) = getfield(x, :val)
5+
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
6+
partial(x::AbstractZero, i) = x
7+
partial(x::CompositeBundle{N, B}, i) where {N, B} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...)
68
primal(x::AbstractTangentBundle) = x.primal
79
primal(z::ZeroTangent) = ZeroTangent()
810

9-
first_partial(x::TangentBundle{1}) = getfield(getfield(x, :partials), 1)
10-
first_partial(x::TaylorBundle{1}) = getfield(getfield(x, :coeffs), 1)
11-
first_partial(x::UniformBundle) = getfield(x, :partial)
12-
first_partial(x::CompositeBundle) = map(first_partial, getfield(x, :tup))
11+
first_partial(x) = partial(x, 1)
1312

1413
# TODO: Which version do we want in ChainRules?
1514
function my_frule(args::ATB{1}...)
@@ -24,22 +23,22 @@ my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing
2423
(::∂☆{N})(::ZeroBundle{N, typeof(my_frule)}, ::ZeroBundle{N, ZeroBundle{1, typeof(my_frule)}}, args::ATB{N}...) where {N} = ZeroBundle{N}(nothing)
2524

2625
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
27-
UniformBundle{minus1(N), <:Any, U}(UniformBundle{1, B, U}(b.primal, b.partial), b.partial)
26+
UniformBundle{minus1(N), <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)
2827

29-
function shuffle_down(b::TangentBundle{N, B}) where {N, B}
28+
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
3029
# N.B: This depends on the special properties of the canonical tangent index order
31-
TangentBundle{N-1}(
32-
TangentBundle{1}(b.primal, (partial(b, 1),)),
30+
ExplicitTangentBundle{N-1}(
31+
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
3332
ntuple(2^(N-1)-1) do i
34-
TangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
33+
ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
3534
end)
3635
end
3736

3837
function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
3938
TaylorBundle{N-1}(
40-
TangentBundle{1}(b.primal, (b.coeffs[1],)),
39+
ExplicitTangentBundle{1}(b.primal, (b.tangent.coeffs[1],)),
4140
ntuple(N-1) do i
42-
TangentBundle{1}(b.coeffs[i], (b.coeffs[i+1],))
41+
ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
4342
end)
4443
end
4544

@@ -60,7 +59,7 @@ function shuffle_up(r::CompositeBundle{1})
6059
if z₁ == z₂
6160
return TaylorBundle{2}(z₀, (z₁, z₁₂))
6261
else
63-
return TangentBundle{2}(z₀, (z₁, z₂, z₁₂))
62+
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
6463
end
6564
end
6665

@@ -86,14 +85,14 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
8685
N+1))
8786
else
8887
return TangentBundle{N+1}(r.tup[1].primal,
89-
(r.tup[1].partials..., primal(b),
88+
(r.tup[1].tangent.partials..., primal(b),
9089
ntuple(i->partial(b,i), 2^(N+1)-1)...))
9190
end
9291
end
9392

9493
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
9594
(a, b) = primal(r)
96-
if r.partial === b
95+
if r.tangent.val === b
9796
u = b
9897
elseif b == NoTangent() && U === ZeroTangent
9998
u = b
@@ -107,7 +106,7 @@ end
107106
struct ∂☆internal{N}; end
108107
struct ∂☆shuffle{N}; end
109108

110-
shuffle_base(r) = TangentBundle{1}(r[1], (r[2],))
109+
shuffle_base(r) = ExplicitTangentBundle{1}(r[1], (r[2],))
111110

112111
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
113112
r = my_frule(args...)
@@ -119,7 +118,7 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
119118
end
120119

121120
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
122-
bundles = map((p,a) -> TangentBundle{1}(a, (p,)), partials, args)
121+
bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args)
123122
result = ∂☆internal{1}()(bundles...)
124123
primal(result), first_partial(result)
125124
end
@@ -142,14 +141,14 @@ end
142141
# Special case rules for performance
143142
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
144143
s = primal(s)
145-
TangentBundle{N}(getfield(primal(x), s),
146-
map(x->lifted_getfield(x, s), x.partials))
144+
ExplicitTangentBundle{N}(getfield(primal(x), s),
145+
map(x->lifted_getfield(x, s), x.tangent.partials))
147146
end
148147

149148
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
150149
s = primal(s)
151150
TaylorBundle{N}(getfield(primal(x), s),
152-
map(y->lifted_getfield(y, s), x.coeffs))
151+
map(y->lifted_getfield(y, s), x.tangent.coeffs))
153152
end
154153

155154
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
@@ -162,16 +161,16 @@ end
162161

163162
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
164163
s = primal(s)
165-
TangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
166-
map(x->lifted_getfield(x, s), x.partials))
164+
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
165+
map(x->lifted_getfield(x, s), x.tangent.partials))
167166
end
168167

169168
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
170-
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.partial)
169+
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val)
171170
end
172171

173172
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
174-
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.partial)
173+
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.tangent.val)
175174
end
176175

177176
function (::∂☆{N})(f::ATB{N, typeof(tuple)}, args::AbstractTangentBundle{N}...) where {N}

src/stage1/mixed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map
9595
∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)),
9696
TaylorBundle{N+M}(x,
9797
(one(x), (zero(x) for i = 1:(N+M-1))...,)))
98-
@assert isa(∂f, TaylorBundle) || isa(∂f, TangentBundle{1})
98+
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
9999
Jet{typeof(x), N+M}(x, ∂f.primal,
100-
isa(∂f, TangentBundle) ? ∂f.partials : ∂f.coeffs)
100+
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
101101
end
102102
∂⃖ₙ(mapev_unbundled, ∂☆ₘ, js, a)
103103
end

0 commit comments

Comments
 (0)