@@ -74,11 +74,8 @@ struct TaylorTangentIndex <: TangentIndex
74
74
i:: Int
75
75
end
76
76
77
- function Base. getindex (a:: AbstractTangentBundle , b:: TaylorTangentIndex )
78
- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
79
- end
80
-
81
77
abstract type AbstractTangentSpace; end
78
+ Base.:(== )(x:: AbstractTangentSpace , y:: AbstractTangentSpace ) = == (promote (x, y)... )
82
79
83
80
"""
84
81
struct ExplicitTangent{P}
@@ -89,13 +86,23 @@ represented by a vector of `2^N-1` partials.
89
86
struct ExplicitTangent{P <: Tuple } <: AbstractTangentSpace
90
87
partials:: P
91
88
end
89
+ Base.:(== )(a:: ExplicitTangent , b:: ExplicitTangent ) = a. partials == b. partials
90
+ Base. hash (tt:: ExplicitTangent , h:: UInt64 ) = hash (tt. partials, h)
91
+
92
+ Base. getindex (tangent:: ExplicitTangent , b:: CanonicalTangentIndex ) = tangent. partials[b. i]
93
+ function Base. getindex (tangent:: ExplicitTangent , b:: TaylorTangentIndex )
94
+ if lastindex (tangent. partials) == exp2 (b. i) - 1
95
+ return tangent. partials[end ]
96
+ end
97
+ # TODO : should we also allow other indexes if all the partials at that level are equal up regardless of order?
98
+ throw (DomainError (b, " $(typeof (tangent)) is not taylor-like. Taylor indexing is ambiguous" ))
99
+ end
100
+
92
101
93
102
@eval struct TaylorTangent{C <: Tuple } <: AbstractTangentSpace
94
103
coeffs:: C
95
104
TaylorTangent (coeffs) = $ (Expr (:new , :(TaylorTangent{typeof (coeffs)}), :coeffs ))
96
105
end
97
- Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
98
- Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
99
106
100
107
"""
101
108
struct TaylorTangent{C}
@@ -122,6 +129,14 @@ by analogy with the (truncated) Taylor series
122
129
"""
123
130
TaylorTangent
124
131
132
+ Base.:(== )(a:: TaylorTangent , b:: TaylorTangent ) = a. coeffs == b. coeffs
133
+ Base. hash (tt:: TaylorTangent , h:: UInt64 ) = hash (tt. coeffs, h)
134
+
135
+
136
+ Base. getindex (tangent:: TaylorTangent , tti:: TaylorTangentIndex ) = tangent. coeffs[tti. i]
137
+ Base. getindex (tangent:: TaylorTangent , tti:: CanonicalTangentIndex ) = tangent. coeffs[count_ones (tti. i)]
138
+
139
+
125
140
"""
126
141
struct ProductTangent{T <: Tuple{Vararg{AbstractTangentSpace}}}
127
142
@@ -141,6 +156,28 @@ useful for representing singleton values.
141
156
struct UniformTangent{U} <: AbstractTangentSpace
142
157
val:: U
143
158
end
159
+ Base. hash (t:: UniformTangent , h:: UInt64 ) = hash (t. val, h)
160
+ Base.:(== )(t1:: UniformTangent , t2:: UniformTangent ) = t1. val == t2. val
161
+
162
+ Base. getindex (tangent:: UniformTangent , :: Any ) = tangent. val
163
+
164
+ # Conversion and promotion
165
+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:AbstractTangentSpace} ) = et
166
+ Base. promote_rule (tt:: Type{<:TaylorTangent} , :: Type{<:AbstractTangentSpace} ) = tt
167
+ Base. promote_rule (et:: Type{<:ExplicitTangent} , :: Type{<:TaylorTangent} ) = et
168
+ Base. promote_rule (:: Type{<:TaylorTangent} , et:: Type{<:ExplicitTangent} ) = et
169
+
170
+ num_partials (:: Type{TaylorTangent{P}} ) where P = fieldcount (P)
171
+ num_partials (:: Type{ExplicitTangent{P}} ) where P = fieldcount (P)
172
+ Base. eltype (:: Type{TaylorTangent{P}} ) where P = eltype (P)
173
+ Base. eltype (:: Type{ExplicitTangent{P}} ) where P = eltype (P)
174
+ function Base. convert (:: Type{T} , ut:: UniformTangent ) where {T<: Union{TaylorTangent, ExplicitTangent} }
175
+ # can't just use T to construct as the inner constructor doesn't accept type params. So get T_wrapper
176
+ T_wrapper = T<: TaylorTangent ? TaylorTangent : ExplicitTangent
177
+ T_wrapper (ntuple (_-> convert (eltype (T), ut. val), num_partials (T)))
178
+ end
179
+ Base. convert (T:: Type{<:ExplicitTangent} , tt:: TaylorTangent ) = ExplicitTangent (ntuple (i-> tt[CanonicalTangentIndex (i)], num_partials (T)))
180
+ # TODO : Should we define the reverse: Explict->Taylor for the cases where that is actually defined?
144
181
145
182
function _TangentBundle end
146
183
@@ -162,7 +199,9 @@ TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
162
199
_TangentBundle (Val {N} (), primal, tangent)
163
200
164
201
Base. hash (tb:: TangentBundle , h:: UInt64 ) = hash (tb. primal, h)
165
- Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = (a. primal == b. primal) && (a. tangent == b. tangent)
202
+ Base.:(== )(a:: TangentBundle , b:: TangentBundle ) = false # different orders
203
+ Base.:(== )(a:: TangentBundle{N} , b:: TangentBundle{N} ) where {N} = (a. primal == b. primal) && (a. tangent == b. tangent)
204
+ Base. getindex (tbun:: TangentBundle , x) = getindex (tbun. tangent, x)
166
205
167
206
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
168
207
@@ -197,12 +236,7 @@ function Base.show(io::IO, x::ExplicitTangentBundle)
197
236
length (x. partials) >= 7 && print (io, " + " , x. partials[7 ], " ∂₁ ∂₂ ∂₃" )
198
237
end
199
238
200
- function Base. getindex (a:: ExplicitTangentBundle{N} , b:: TaylorTangentIndex ) where {N}
201
- if b. i === N
202
- return a. tangent. partials[end ]
203
- end
204
- error (" $(typeof (a)) is not taylor-like. Taylor indexing is ambiguous" )
205
- end
239
+
206
240
207
241
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
208
242
@@ -233,11 +267,6 @@ function Base.show(io::IO, x::TaylorBundle{1})
233
267
print (io, x. coeffs[1 ], " ∂₁" )
234
268
end
235
269
236
- Base. getindex (tb:: TaylorBundle , tti:: TaylorTangentIndex ) = tb. tangent. coeffs[tti. i]
237
- function Base. getindex (tb:: TaylorBundle , tti:: CanonicalTangentIndex )
238
- tb. tangent. coeffs[count_ones (tti. i)]
239
- end
240
-
241
270
" for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple"
242
271
function destructure (r:: TaylorBundle{N, B} ) where {N, B<: Tuple }
243
272
return ntuple (fieldcount (B)) do field_ii
@@ -307,8 +336,18 @@ function Base.show(io::IO, t::AbstractZeroBundle{N}) where N
307
336
print (io, " )" )
308
337
end
309
338
339
+ # Conversion and promotion
340
+ function Base. promote_rule (:: Type{TangentBundle{N, B, P1}} , :: Type{TangentBundle{N, B, P2}} ) where {N,B,P1,P2}
341
+ return TangentBundle{N, B, promote_type (P1, P2)}
342
+ end
343
+
344
+ function Base. convert (:: Type{T} , tbun:: TangentBundle{N, B} ) where {N, B, P, T<: TangentBundle{N,B,P} }
345
+ the_primal = convert (B, primal (tbun))
346
+ the_partials = convert (P, tbun. tangent)
347
+ return _TangentBundle (Val {N} (), the_primal, the_partials)
348
+ end
310
349
311
- Base . getindex (u :: UniformBundle , :: TaylorTangentIndex ) = u . tangent . val
350
+ # StructureArrays helpers
312
351
313
352
expand_singleton_to_array (asize, a:: AbstractZero ) = fill (a, asize... )
314
353
expand_singleton_to_array (asize, a:: AbstractArray ) = a
0 commit comments