@@ -90,6 +90,11 @@ struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
90
90
partials:: P
91
91
end
92
92
93
+ @eval struct TaylorTangent{C <: Tuple } <: AbstractTangentSpace
94
+ coeffs:: C
95
+ TaylorTangent (coeffs) = $ (Expr (:new , :(TaylorTangent{typeof (coeffs)}), :coeffs ))
96
+ end
97
+
93
98
"""
94
99
struct TaylorTangent{C}
95
100
@@ -113,9 +118,7 @@ by analogy with the (truncated) Taylor series
113
118
114
119
c₀ + c₁ x + 1/2 c₂ x² + 1/3! c₃ x³ + O(x⁴)
115
120
"""
116
- struct TaylorTangent{C <: Tuple } <: AbstractTangentSpace
117
- coeffs:: C
118
- end
121
+ TaylorTangent
119
122
120
123
"""
121
124
struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
@@ -137,20 +140,24 @@ struct UniformTangent{U} <: AbstractTangentSpace
137
140
val:: U
138
141
end
139
142
143
+ function _TangentBundle end
144
+
145
+ @eval struct TangentBundle{N, B, P <: AbstractTangentSpace } <: AbstractTangentBundle{N, B}
146
+ primal:: B
147
+ tangent:: P
148
+ global _TangentBundle (:: Val{N} , primal:: B , tangent:: P ) where {N, B, P} = $ (Expr (:new , :(TangentBundle{N, Core. Typeof (primal), typeof (tangent)}), :primal , :tangent ))
149
+ end
150
+
140
151
"""
141
152
struct TangentBundle{N, B, P}
142
153
143
154
Represents a tangent bundle as an explicit primal together
144
155
with some representation of (potentially a product of) the tangent space.
145
156
"""
146
- struct TangentBundle{N, B, P <: AbstractTangentSpace } <: AbstractTangentBundle{N, B}
147
- primal:: B
148
- tangent:: P
149
- TangentBundle {N, B, P} (primal:: B , tangent:: P ) where {N, B, P} = new {N, B, P} (primal, tangent)
150
- end
157
+ TangentBundle
151
158
152
159
TangentBundle {N} (primal:: B , tangent:: P ) where {N, B, P<: AbstractTangentSpace } =
153
- TangentBundle {N, B, P} ( primal, tangent)
160
+ _TangentBundle ( Val {N} (), primal, tangent)
154
161
155
162
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
156
163
@@ -159,17 +166,17 @@ check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
159
166
160
167
function ExplicitTangentBundle {N} (primal:: B , partials:: P ) where {N, B, P}
161
168
check_tangent_invariant (length (partials), N)
162
- TangentBundle {N, Core.Typeof(primal ), ExplicitTangent{P}} ( primal, ExplicitTangent {P} (partials))
169
+ _TangentBundle ( Val {N} ( ), primal, ExplicitTangent {P} (partials))
163
170
end
164
171
165
172
function ExplicitTangentBundle {N,B} (primal:: B , partials:: P ) where {N, B, P}
166
173
check_tangent_invariant (length (partials), N)
167
- TangentBundle {N, B, ExplicitTangent{P}} ( primal, ExplicitTangent {P} (partials))
174
+ _TangentBundle ( Val {N} (), primal, ExplicitTangent {P} (partials))
168
175
end
169
176
170
177
function ExplicitTangentBundle {N,B,P} (primal:: B , partials:: P ) where {N, B, P}
171
178
check_tangent_invariant (length (partials), N)
172
- TangentBundle {N, B, ExplicitTangent{P}} ( primal, ExplicitTangent {P} (partials))
179
+ _TangentBundle ( Val {N} (), primal, ExplicitTangent {P} (partials))
173
180
end
174
181
175
182
function Base. show (io:: IO , x:: ExplicitTangentBundle )
194
201
195
202
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
196
203
197
- function TaylorBundle {N, B} (primal:: B , coeffs:: P ) where {N, B, P }
204
+ function TaylorBundle {N, B} (primal:: B , coeffs) where {N, B}
198
205
check_taylor_invariants (coeffs, primal, N)
199
- TangentBundle {N, B, TaylorTangent{P}} ( primal, TaylorTangent {P} (coeffs))
206
+ _TangentBundle ( Val {N} (), primal, TaylorTangent (coeffs))
200
207
end
201
208
202
209
function check_taylor_invariants (coeffs, primal, N)
208
215
@ChainRulesCore . non_differentiable check_taylor_invariants (coeffs, primal, N)
209
216
210
217
function TaylorBundle {N} (primal, coeffs) where {N}
211
- TaylorBundle {N, Core.Typeof( primal)} (primal, coeffs)
218
+ _TangentBundle ( Val {N} (), primal, TaylorTangent ( coeffs) )
212
219
end
213
220
214
221
function Base. show (io:: IO , x:: TaylorBundle{1} )
@@ -224,12 +231,12 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
224
231
end
225
232
226
233
const UniformBundle{N, B, U} = TangentBundle{N, B, UniformTangent{U}}
227
- UniformBundle {N, B, U} (primal:: B , partial:: U ) where {N,B,U} = UniformBundle {N, B, U} ( primal, UniformTangent {U} (partial))
228
- UniformBundle {N, B, U} (primal:: B ) where {N,B,U} = UniformBundle {N, B, U} ( primal, UniformTangent {U} (U. instance))
229
- UniformBundle {N, B} (primal:: B , partial:: U ) where {N,B,U} = UniformBundle {N, Core.Typeof(primal), U} ( primal, UniformTangent {U} (partial))
230
- UniformBundle {N} (primal, partial:: U ) where {N,U} = UniformBundle {N, Core.Typeof(primal ), U} ( primal, UniformTangent {U} (partial))
231
- UniformBundle {N, <:Any, U} (primal, partial:: U ) where {N, U} = UniformBundle {N, Core.Typeof(primal ), U} ( primal, UniformTangent {U} (U. instance))
232
- UniformBundle {N, <:Any, U} (primal) where {N, U} = UniformBundle {N, Core.Typeof(primal ), U} ( primal, UniformTangent {U} (U. instance))
234
+ UniformBundle {N, B, U} (primal:: B , partial:: U ) where {N,B,U} = _TangentBundle ( Val {N} (), primal, UniformTangent {U} (partial))
235
+ UniformBundle {N, B, U} (primal:: B ) where {N,B,U} = _TangentBundle ( Val {N} (), primal, UniformTangent {U} (U. instance))
236
+ UniformBundle {N, B} (primal:: B , partial:: U ) where {N,B,U} = _TangentBundle ( Val {N} (), primal, UniformTangent {U} (partial))
237
+ UniformBundle {N} (primal, partial:: U ) where {N,U} = _TangentBundle ( Val {N} ( ), primal, UniformTangent {U} (partial))
238
+ UniformBundle {N, <:Any, U} (primal, partial:: U ) where {N, U} = _TangentBundle ( Val {N} ( ), primal, UniformTangent {U} (U. instance))
239
+ UniformBundle {N, <:Any, U} (primal) where {N, U} = _TangentBundle ( Val {N} ( ), primal, UniformTangent {U} (U. instance))
233
240
234
241
const ZeroBundle{N, B} = UniformBundle{N, B, ZeroTangent}
235
242
const DNEBundle{N, B} = UniformBundle{N, B, NoTangent}
0 commit comments