@@ -74,11 +74,8 @@ struct TaylorTangentIndex <: TangentIndex
74
74
i:: Int
75
75
end
76
76
77
- function Base. getindex (a:: AbstractTangentBundle , b:: TaylorTangentIndex )
78
- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
79
- end
80
-
81
77
abstract type AbstractTangentSpace; end
78
+ Base.:(== )(x:: AbstractTangentSpace , y:: AbstractTangentSpace ) = == (promote (x, y)... )
82
79
83
80
"""
84
81
struct ExplicitTangent{P}
@@ -89,13 +86,23 @@ represented by a vector of `2^N-1` partials.
89
86
struct ExplicitTangent{P <: Tuple } <: AbstractTangentSpace
90
87
partials:: P
91
88
end
89
+ Base.:(== )(a:: ExplicitTangent , b:: ExplicitTangent ) = a. partials == b. partials
90
+ Base. hash (tt:: ExplicitTangent , h:: UInt64 ) = hash (tt. partials, h)
91
+
92
+ Base. getindex (tangent:: ExplicitTangent , b:: CanonicalTangentIndex ) = tangent. partials[b. i]
93
+ function Base. getindex (tangent:: ExplicitTangent , b:: TaylorTangentIndex )
94
+ if lastindex (tangent. partials) == exp2 (b. i) - 1
95
+ return tangent. partials[end ]
96
+ end
97
+ # TODO : should we also allow other indexes if all the partials at that level are equal up regardless of order?
98
+ throw (DomainError (b, " $(typeof (tangent)) is not taylor-like. Taylor indexing is ambiguous" ))
99
+ end
100
+
92
101
93
102
@eval struct TaylorTangent{C <: Tuple } <: AbstractTangentSpace
94
103
coeffs:: C
95
104
TaylorTangent (coeffs) = $ (Expr (:new , :(TaylorTangent{typeof (coeffs)}), :coeffs ))
96
105
end
97
- Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
98
- Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
99
106
100
107
"""
101
108
struct TaylorTangent{C}
@@ -122,15 +129,13 @@ by analogy with the (truncated) Taylor series
122
129
"""
123
130
TaylorTangent
124
131
125
- """
126
- struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
132
+ Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
133
+ Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
134
+
135
+
136
+ Base. getindex (tangent:: TaylorTangent , tti:: TaylorTangentIndex ) = tangent. coeffs[tti. i]
137
+ Base. getindex (tangent:: TaylorTangent , tti:: CanonicalTangentIndex ) = tangent. coeffs[count_ones (tti. i)]
127
138
128
- Represents the product space of the given representations of the
129
- tangent space.
130
- """
131
- struct ProductTangent{T <: Tuple } <: AbstractTangentSpace
132
- factors:: T
133
- end
134
139
135
140
"""
136
141
struct UniformTangent
@@ -141,6 +146,28 @@ useful for representing singleton values.
141
146
struct UniformTangent{U} <: AbstractTangentSpace
142
147
val:: U
143
148
end
149
+ Base. hash (t:: UniformTangent , h:: UInt64 ) = hash (t. val, h)
150
+ Base.:(== )(t1:: UniformTangent , t2:: UniformTangent ) = t1. val == t2. val
151
+
152
+ Base. getindex (tangent:: UniformTangent , :: Any ) = tangent. val
153
+
154
+ # Conversion and promotion
155
+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:AbstractTangentSpace} ) = et
156
+ Base. promote_rule (tt:: Type{<:TaylorTangent} , :: Type{<:AbstractTangentSpace} ) = tt
157
+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:TaylorTangent} ) = et
158
+ Base. promote_rule (:: Type{<:TaylorTangent} , et:: Type{<:ExplicitTangent} ) = et
159
+
160
+ num_partials (:: Type{TaylorTangent{P}} ) where P = fieldcount (P)
161
+ num_partials (:: Type{ExplicitTangent{P}} ) where P = fieldcount (P)
162
+ Base. eltype (:: Type{TaylorTangent{P}} ) where P = eltype (P)
163
+ Base. eltype (:: Type{ExplicitTangent{P}} ) where P = eltype (P)
164
+ function Base. convert (:: Type{T} , ut:: UniformTangent ) where {T<: Union{TaylorTangent, ExplicitTangent} }
165
+ # can't just use T to construct as the inner constructor doesn't accept type params. So get T_wrapper
166
+ T_wrapper = T<: TaylorTangent ? TaylorTangent : ExplicitTangent
167
+ T_wrapper (ntuple (_-> convert (eltype (T), ut. val), num_partials (T)))
168
+ end
169
+ Base. convert (T:: Type{<:ExplicitTangent} , tt:: TaylorTangent ) = ExplicitTangent (ntuple (i-> tt[CanonicalTangentIndex (i)], num_partials (T)))
170
+ # TODO : Should we define the reverse: Explict->Taylor for the cases where that is actually defined?
144
171
145
172
function _TangentBundle end
146
173
@@ -154,15 +181,17 @@ end
154
181
struct TangentBundle{N, B, P}
155
182
156
183
Represents a tangent bundle as an explicit primal together
157
- with some representation of (potentially a product of) the tangent space.
184
+ with some representation of the tangent space.
158
185
"""
159
186
TangentBundle
160
187
161
188
TangentBundle {N} (primal:: B , tangent:: P ) where {N, B, P<: AbstractTangentSpace } =
162
189
_TangentBundle (Val {N} (), primal, tangent)
163
190
164
191
Base. hash (tb:: TangentBundle , h:: UInt64 ) = hash (tb. primal, h)
165
- Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = (a. primal == b. primal) && (a. tangent == b. tangent)
192
+ Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = false # different orders
193
+ Base.:(== )(a:: TangentBundle{N} , b:: TangentBundle{N} ) where {N} = (a. primal == b. primal) && (a. tangent == b. tangent)
194
+ Base. getindex (tbun:: TangentBundle , x) = getindex (tbun. tangent, x)
166
195
167
196
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
168
197
@@ -197,12 +226,7 @@ function Base.show(io::IO, x::ExplicitTangentBundle)
197
226
length (x. partials) >= 7 && print (io, " + " , x. partials[7 ], " ∂₁ ∂₂ ∂₃" )
198
227
end
199
228
200
- function Base. getindex (a:: ExplicitTangentBundle{N} , b:: TaylorTangentIndex ) where {N}
201
- if b. i === N
202
- return a. tangent. partials[end ]
203
- end
204
- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
205
- end
229
+
206
230
207
231
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
208
232
@@ -233,11 +257,6 @@ function Base.show(io::IO, x::TaylorBundle{1})
233
257
print (io, x. coeffs[1 ], " ∂₁" )
234
258
end
235
259
236
- Base. getindex (tb:: TaylorBundle , tti:: TaylorTangentIndex ) = tb. tangent. coeffs[tti. i]
237
- function Base. getindex (tb:: TaylorBundle , tti:: CanonicalTangentIndex )
238
- tb. tangent. coeffs[count_ones (tti. i)]
239
- end
240
-
241
260
" for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple"
242
261
function destructure (r:: TaylorBundle{N, B} ) where {N, B<: Tuple }
243
262
return ntuple (fieldcount (B)) do field_ii
@@ -307,8 +326,18 @@ function Base.show(io::IO, t::AbstractZeroBundle{N}) where N
307
326
print (io, " )" )
308
327
end
309
328
329
+ # Conversion and promotion
330
+ function Base. promote_rule (:: Type{TangentBundle{N, B, P1}} , :: Type{TangentBundle{N, B, P2}} ) where {N,B,P1,P2}
331
+ return TangentBundle{N, B, promote_type (P1, P2)}
332
+ end
333
+
334
+ function Base. convert (:: Type{T} , tbun:: TangentBundle{N, B} ) where {N, B, P, T<: TangentBundle{N,B,P} }
335
+ the_primal = convert (B, primal (tbun))
336
+ the_partials = convert (P, tbun. tangent)
337
+ return _TangentBundle (Val {N} (), the_primal, the_partials)
338
+ end
310
339
311
- Base . getindex (u :: UniformBundle , :: TaylorTangentIndex ) = u . tangent . val
340
+ # StructureArrays helpers
312
341
313
342
expand_singleton_to_array (asize, a:: AbstractZero ) = fill (a, asize... )
314
343
expand_singleton_to_array (asize, a:: AbstractArray ) = a
0 commit comments