Skip to content

Commit 0241872

Browse files
authored
Improve inferability of forward mode data structures (#104)
None of this is particularly, pretty, but these changes help the compiler constprop things better, so let's do this for now. We can revisit in the future when the compiler learns more tricks.
1 parent c6f1a48 commit 0241872

File tree

3 files changed

+41
-28
lines changed

3 files changed

+41
-28
lines changed

src/extra_rules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x:
172172
end
173173

174174
function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
175-
SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x)
175+
SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing)
176176
end
177177

178178
function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}

src/stage1/forward.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,24 @@ end
145145
map(x->lifted_getfield(x, s), x.tangent.partials))
146146
end
147147

148+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
149+
s = primal(s)
150+
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
151+
map(x->lifted_getfield(x, s), x.tangent.partials))
152+
end
153+
148154
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
149155
s = primal(s)
150156
TaylorBundle{N}(getfield(primal(x), s),
151157
map(y->lifted_getfield(y, s), x.tangent.coeffs))
152158
end
153159

160+
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}, inbounds::ATB{N}) where {N}
161+
s = primal(s)
162+
TaylorBundle{N}(getfield(primal(x), s, primal(inbounds)),
163+
map(y->lifted_getfield(y, s), x.tangent.coeffs))
164+
end
165+
154166
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
155167
x.tup[primal(s)]
156168
end
@@ -159,12 +171,6 @@ end
159171
x.tup[Base.fieldindex(B, primal(s))]
160172
end
161173

162-
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
163-
s = primal(s)
164-
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
165-
map(x->lifted_getfield(x, s), x.tangent.partials))
166-
end
167-
168174
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
169175
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val)
170176
end

src/tangent.jl

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
9090
partials::P
9191
end
9292

93+
@eval struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
94+
coeffs::C
95+
TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs))
96+
end
97+
9398
"""
9499
struct TaylorTangent{C}
95100
@@ -113,9 +118,7 @@ by analogy with the (truncated) Taylor series
113118
114119
c₀ + c₁ x + 1/2 c₂ x² + 1/3! c₃ x³ + O(x⁴)
115120
"""
116-
struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
117-
coeffs::C
118-
end
121+
TaylorTangent
119122

120123
"""
121124
struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
@@ -137,20 +140,24 @@ struct UniformTangent{U} <: AbstractTangentSpace
137140
val::U
138141
end
139142

143+
function _TangentBundle end
144+
145+
@eval struct TangentBundle{N, B, P <: AbstractTangentSpace} <: AbstractTangentBundle{N, B}
146+
primal::B
147+
tangent::P
148+
global _TangentBundle(::Val{N}, primal::B, tangent::P) where {N, B, P} = $(Expr(:new, :(TangentBundle{N, Core.Typeof(primal), typeof(tangent)}), :primal, :tangent))
149+
end
150+
140151
"""
141152
struct TangentBundle{N, B, P}
142153
143154
Represents a tangent bundle as an explicit primal together
144155
with some representation of (potentially a product of) the tangent space.
145156
"""
146-
struct TangentBundle{N, B, P <: AbstractTangentSpace} <: AbstractTangentBundle{N, B}
147-
primal::B
148-
tangent::P
149-
TangentBundle{N, B, P}(primal::B, tangent::P) where {N, B, P} = new{N, B, P}(primal, tangent)
150-
end
157+
TangentBundle
151158

152159
TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
153-
TangentBundle{N, B, P}(primal, tangent)
160+
_TangentBundle(Val{N}(), primal, tangent)
154161

155162
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
156163

@@ -159,17 +166,17 @@ check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
159166

160167
function ExplicitTangentBundle{N}(primal::B, partials::P) where {N, B, P}
161168
check_tangent_invariant(length(partials), N)
162-
TangentBundle{N, Core.Typeof(primal), ExplicitTangent{P}}(primal, ExplicitTangent{P}(partials))
169+
_TangentBundle(Val{N}(), primal, ExplicitTangent{P}(partials))
163170
end
164171

165172
function ExplicitTangentBundle{N,B}(primal::B, partials::P) where {N, B, P}
166173
check_tangent_invariant(length(partials), N)
167-
TangentBundle{N, B, ExplicitTangent{P}}(primal, ExplicitTangent{P}(partials))
174+
_TangentBundle(Val{N}(), primal, ExplicitTangent{P}(partials))
168175
end
169176

170177
function ExplicitTangentBundle{N,B,P}(primal::B, partials::P) where {N, B, P}
171178
check_tangent_invariant(length(partials), N)
172-
TangentBundle{N, B, ExplicitTangent{P}}(primal, ExplicitTangent{P}(partials))
179+
_TangentBundle(Val{N}(), primal, ExplicitTangent{P}(partials))
173180
end
174181

175182
function Base.show(io::IO, x::ExplicitTangentBundle)
@@ -194,9 +201,9 @@ end
194201

195202
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
196203

197-
function TaylorBundle{N, B}(primal::B, coeffs::P) where {N, B, P}
204+
function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
198205
check_taylor_invariants(coeffs, primal, N)
199-
TangentBundle{N, B, TaylorTangent{P}}(primal, TaylorTangent{P}(coeffs))
206+
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
200207
end
201208

202209
function check_taylor_invariants(coeffs, primal, N)
@@ -208,7 +215,7 @@ end
208215
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)
209216

210217
function TaylorBundle{N}(primal, coeffs) where {N}
211-
TaylorBundle{N, Core.Typeof(primal)}(primal, coeffs)
218+
_TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
212219
end
213220

214221
function Base.show(io::IO, x::TaylorBundle{1})
@@ -224,12 +231,12 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
224231
end
225232

226233
const UniformBundle{N, B, U} = TangentBundle{N, B, UniformTangent{U}}
227-
UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = UniformBundle{N, B, U}(primal, UniformTangent{U}(partial))
228-
UniformBundle{N, B, U}(primal::B) where {N,B,U} = UniformBundle{N, B, U}(primal, UniformTangent{U}(U.instance))
229-
UniformBundle{N, B}(primal::B, partial::U) where {N,B,U} = UniformBundle{N, Core.Typeof(primal), U}(primal, UniformTangent{U}(partial))
230-
UniformBundle{N}(primal, partial::U) where {N,U} = UniformBundle{N, Core.Typeof(primal), U}(primal, UniformTangent{U}(partial))
231-
UniformBundle{N, <:Any, U}(primal, partial::U) where {N, U} = UniformBundle{N, Core.Typeof(primal), U}(primal, UniformTangent{U}(U.instance))
232-
UniformBundle{N, <:Any, U}(primal) where {N, U} = UniformBundle{N, Core.Typeof(primal), U}(primal, UniformTangent{U}(U.instance))
234+
UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial))
235+
UniformBundle{N, B, U}(primal::B) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))
236+
UniformBundle{N, B}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(),primal, UniformTangent{U}(partial))
237+
UniformBundle{N}(primal, partial::U) where {N,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial))
238+
UniformBundle{N, <:Any, U}(primal, partial::U) where {N, U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))
239+
UniformBundle{N, <:Any, U}(primal) where {N, U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))
233240

234241
const ZeroBundle{N, B} = UniformBundle{N, B, ZeroTangent}
235242
const DNEBundle{N, B} = UniformBundle{N, B, NoTangent}

0 commit comments

Comments
 (0)