Skip to content

Commit 7df36ef

Browse files
committed
prep for fearless (Taylor's Version)
1 parent 0a11f37 commit 7df36ef

File tree

5 files changed

+43
-68
lines changed

5 files changed

+43
-68
lines changed

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued
6464
For `x` in a one dimensional manifold, map x to the trivial, unital, 1st order
6565
tangent bundle. It should hold that `∀x ⟨∂x(x), dx(x)⟩ = 1`
6666
"""
67-
∂x(x::Real) = ExplicitTangentBundle{1}(x, (one(x),))
67+
∂x(x::Real) = TaylorBundle{1}(x, (one(x),))
6868
∂x(x) = error("Tangent space not defined for `$(typeof(x)).")
6969

7070
struct ∂xⁿ{N}; end

src/jet.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
struct Jet{T, N}
2+
struct Jet{S, T, N}
33
44
Represents the truncated (N-1)-th order Taylor series
55
@@ -15,8 +15,8 @@ For a jet `j`, several operations are supported:
1515
derivatives. Mathematically this corresponds to an infinitessimal ball
1616
around `a`.
1717
"""
18-
struct Jet{T, N}
19-
a::T
18+
struct Jet{S, T, N}
19+
a::S
2020
f₀::T
2121
fₙ::NTuple{N, T}
2222
end
@@ -25,13 +25,13 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), j::Jet, s)
2525
error("Raw getproperty not allowed in AD code")
2626
end
2727

28-
function Base.:+(j1::Jet{T, N}, j2::Jet{T, N}) where {T, N}
28+
function Base.:+(j1::Jet{S, T, N}, j2::Jet{S, T, N}) where {S, T, N}
2929
@assert j1.a === j2.a
30-
Jet{T, N}(j1.a, j1.f₀ + j2.f₀, map(+, j1.fₙ, j2.fₙ))
30+
Jet{S, T, N}(j1.a, j1.f₀ + j2.f₀, map(+, j1.fₙ, j2.fₙ))
3131
end
3232

33-
function Base.:+(j::Jet{T, N}, x::T) where {T, N}
34-
Jet{T, N}(j.a, j.f₀+x, j.fₙ)
33+
function Base.:+(j::Jet{S, T, N}, x::T) where {S, T, N}
34+
Jet{S, T, N}(j.a, j.f₀+x, j.fₙ)
3535
end
3636

3737
struct One; end
@@ -44,28 +44,28 @@ function ChainRulesCore.rrule(::typeof(+), j::Jet, x::One)
4444
j + x, Δ->(NoTangent(), One(), ZeroTangent())
4545
end
4646

47-
function Base.zero(j::Jet{T, N}) where {T, N}
47+
function Base.zero(j::Jet{S, T, N}) where {S, T, N}
4848
let z = zero(j[0])
49-
Jet{T, N}(j.a, z,
49+
Jet{S, T, N}(j.a, z,
5050
ntuple(_->z, N))
5151
end
5252
end
5353
function ChainRulesCore.rrule(::typeof(Base.zero), j::Jet)
5454
zero(j), Δ->(NoTangent(), ZeroTangent())
5555
end
5656

57-
function Base.getindex(j::Jet{T, N}, i::Integer) where {T, N}
57+
function Base.getindex(j::Jet{S, T, N}, i::Integer) where {S, T, N}
5858
(0 <= i <= N) || throw(BoundsError(j, i))
5959
i == 0 && return j.f₀
6060
return j.fₙ[i]
6161
end
6262

63-
function deriv(j::Jet{T, N}) where {T, N}
64-
Jet{T, N-1}(j.a, j.fₙ[1], Base.tail(j.fₙ))
63+
function deriv(j::Jet{S, T, N}) where {S, T, N}
64+
Jet{S, T, N-1}(j.a, j.fₙ[1], Base.tail(j.fₙ))
6565
end
6666

67-
function integrate(j::Jet{T, N}) where {T, N}
68-
Jet{T, N+1}(j.a, zero(j.f₀), tuple(j.f₀, j.fₙ...))
67+
function integrate(j::Jet{S, T, N}) where {S, T, N}
68+
Jet{S, T, N+1}(j.a, zero(j.f₀), tuple(j.f₀, j.fₙ...))
6969
end
7070

7171
deriv(::NoTangent) = NoTangent()
@@ -188,7 +188,7 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(map), f, a::Array) where {N}
188188
TaylorBundle{N}(x,
189189
(one(x), (zero(x) for i = 1:(N-1))...,)))
190190
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
191-
Jet{typeof(x), N}(x, ∂f.primal,
191+
Jet{typeof(x), typeof(x), N}(x, ∂f.primal,
192192
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
193193
end
194194
∂⃖ₙ(mapev, js, a)
@@ -239,7 +239,7 @@ expressions for the t′ᵢ that are hopefully easier on the compiler.
239239
end...)
240240
end
241241

242-
@generated function (j::Jet{T, N} where T)(x::TaylorBundle{M}) where {N, M}
242+
@generated function (j::Jet{S, T, N} where {S, T})(x::TaylorBundle{M}) where {N, M}
243243
O = min(M,N)
244244
quote
245245
domain_check(j, x.primal)
@@ -249,12 +249,12 @@ end
249249
end
250250
end
251251

252-
function (j::Jet{T, 1} where T)(x::ExplicitTangentBundle{1})
252+
function (j::Jet{S, T, 1} where {S,T})(x::ExplicitTangentBundle{1})
253253
domain_check(j, x.primal)
254254
coeffs = x.tangent.partials
255255
ExplicitTangentBundle{1}(j[0], (jet_taylor_ev(Val{1}(), coeffs, j),))
256256
end
257257

258-
function (j::Jet{T, N} where T)(x::ExplicitTangentBundle{N, M}) where {N, M}
258+
function (j::Jet{S, T, N} where T)(x::ExplicitTangentBundle{N, M}) where {S, N, M}
259259
error("TODO")
260260
end

src/stage1/forward.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ end
3636

3737
function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
3838
TaylorBundle{N-1}(
39-
ExplicitTangentBundle{1}(b.primal, (b.tangent.coeffs[1],)),
39+
TaylorBundle{1}(b.primal, (b.tangent.coeffs[1],)),
4040
ntuple(N-1) do i
41-
ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
41+
TaylorBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
4242
end)
4343
end
4444

@@ -54,13 +54,8 @@ end
5454
function shuffle_up(r::CompositeBundle{1})
5555
z₀ = primal(r.tup[1])
5656
z₁ = partial(r.tup[1], 1)
57-
z₂ = primal(r.tup[2])
5857
z₁₂ = partial(r.tup[2], 1)
59-
if z₁ == z₂
60-
return TaylorBundle{2}(z₀, (z₁, z₁₂))
61-
else
62-
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
63-
end
58+
return TaylorBundle{2}(z₀, (z₁, z₁₂))
6459
end
6560

6661
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
@@ -118,7 +113,7 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
118113
end
119114

120115
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
121-
bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args)
116+
bundles = map((p,a) -> TaylorBundle{1}(a, (p,)), partials, args)
122117
result = ∂☆internal{1}()(bundles...)
123118
primal(result), first_partial(result)
124119
end
@@ -139,13 +134,13 @@ end
139134
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
140135

141136
# Special case rules for performance
142-
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
137+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ExplicitTangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
143138
s = primal(s)
144139
ExplicitTangentBundle{N}(getfield(primal(x), s),
145140
map(x->lifted_getfield(x, s), x.tangent.partials))
146141
end
147142

148-
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
143+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ExplicitTangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
149144
s = primal(s)
150145
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
151146
map(x->lifted_getfield(x, s), x.tangent.partials))

src/tangent.jl

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,8 @@ struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
9090
partials::P
9191
end
9292

93-
@eval struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
93+
struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
9494
coeffs::C
95-
TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs))
9695
end
9796

9897
"""
@@ -140,24 +139,17 @@ struct UniformTangent{U} <: AbstractTangentSpace
140139
val::U
141140
end
142141

143-
function _TangentBundle end
144-
145-
@eval struct TangentBundle{N, B, P <: AbstractTangentSpace} <: AbstractTangentBundle{N, B}
146-
primal::B
147-
tangent::P
148-
global _TangentBundle(::Val{N}, primal::B, tangent::P) where {N, B, P} = $(Expr(:new, :(TangentBundle{N, Core.Typeof(primal), typeof(tangent)}), :primal, :tangent))
149-
end
150-
151142
"""
152143
struct TangentBundle{N, B, P}
153144
154145
Represents a tangent bundle as an explicit primal together
155146
with some representation of (potentially a product of) the tangent space.
156147
"""
157-
TangentBundle
158-
159-
TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
160-
_TangentBundle(Val{N}(), primal, tangent)
148+
struct TangentBundle{N, B, P <: AbstractTangentSpace} <: AbstractTangentBundle{N, B}
149+
primal::B
150+
tangent::P
151+
TangentBundle{N}(B, P) where {N} = new{N, typeof(B), typeof(P)}(B,P)
152+
end
161153

162154
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
163155

@@ -166,17 +158,17 @@ check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
166158

167159
function ExplicitTangentBundle{N}(primal::B, partials::P) where {N, B, P}
168160
check_tangent_invariant(length(partials), N)
169-
_TangentBundle(Val{N}(), primal, ExplicitTangent{P}(partials))
161+
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
170162
end
171163

172164
function ExplicitTangentBundle{N,B}(primal::B, partials::P) where {N, B, P}
173165
check_tangent_invariant(length(partials), N)
174-
_TangentBundle(Val{N}(), primal, ExplicitTangent{P}(partials))
166+
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
175167
end
176168

177169
function ExplicitTangentBundle{N,B,P}(primal::B, partials::P) where {N, B, P}
178170
check_tangent_invariant(length(partials), N)
179-
_TangentBundle(Val{N}(), primal, ExplicitTangent{P}(partials))
171+
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
180172
end
181173

182174
function Base.show(io::IO, x::ExplicitTangentBundle)
@@ -203,7 +195,7 @@ const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
203195

204196
function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
205197
check_taylor_invariants(coeffs, primal, N)
206-
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
198+
TangentBundle{N}(primal, TaylorTangent(coeffs))
207199
end
208200

209201
function check_taylor_invariants(coeffs, primal, N)
@@ -215,7 +207,7 @@ end
215207
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)
216208

217209
function TaylorBundle{N}(primal, coeffs) where {N}
218-
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
210+
TangentBundle{N}(primal, TaylorTangent(coeffs))
219211
end
220212

221213
function Base.show(io::IO, x::TaylorBundle{1})
@@ -230,25 +222,13 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
230222
tb.tangent.coeffs[count_ones(tti.i)]
231223
end
232224

233-
function truncate(tt::TaylorTangent, order::Val{N}) where {N}
234-
TaylorTangent(tt.coeffs[1:N])
235-
end
236-
237-
function truncate(ut::UniformTangent, order::Val)
238-
ut
239-
end
240-
241-
function truncate(tb::TangentBundle, order::Val)
242-
_TangentBundle(order, tb.primal, truncate(tb.tangent, order))
243-
end
244-
245225
const UniformBundle{N, B, U} = TangentBundle{N, B, UniformTangent{U}}
246-
UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial))
247-
UniformBundle{N, B, U}(primal::B) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))
248-
UniformBundle{N, B}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(),primal, UniformTangent{U}(partial))
249-
UniformBundle{N}(primal, partial::U) where {N,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial))
250-
UniformBundle{N, <:Any, U}(primal, partial::U) where {N, U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))
251-
UniformBundle{N, <:Any, U}(primal) where {N, U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))
226+
UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = TangentBundle{N}(primal, UniformTangent{U}(partial))
227+
UniformBundle{N, B, U}(primal::B) where {N,B,U} = TangentBundle{N}(primal, UniformTangent{U}(U.instance))
228+
UniformBundle{N, B}(primal::B, partial::U) where {N,B,U} = TangentBundle{N}(primal, UniformTangent{U}(partial))
229+
UniformBundle{N}(primal, partial::U) where {N,U} = TangentBundle{N}(primal, UniformTangent{U}(partial))
230+
UniformBundle{N, <:Any, U}(primal, partial::U) where {N, U} = TangentBundle{N}(primal, UniformTangent{U}(U.instance))
231+
UniformBundle{N, <:Any, U}(primal) where {N, U} = TangentBundle{N}(primal, UniformTangent{U}(U.instance))
252232

253233
const ZeroBundle{N, B} = UniformBundle{N, B, ZeroTangent}
254234
const DNEBundle{N, B} = UniformBundle{N, B, NoTangent}

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()
5050

5151
# Minimal 2-nd order forward smoke test
5252
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
53-
Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
53+
Diffractor.TaylorBundle{2}(1.0, (1.0 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
5454

5555
function simple_control_flow(b, x)
5656
if b

0 commit comments

Comments
 (0)