@@ -16,7 +16,7 @@ It should not be passed in by user.
16
16
For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly
17
17
to for a tuple.
18
18
For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values
19
- via `comp .fieldname`.
19
+ via `tangent .fieldname`.
20
20
Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`.
21
21
To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref)
22
22
function is provided.
@@ -56,80 +56,80 @@ Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false
56
56
57
57
Base. hash (a:: Tangent , h:: UInt ) = Base. hash (backing (canonicalize (a)), h)
58
58
59
- function Base. show (io:: IO , comp :: Tangent{P} ) where P
59
+ function Base. show (io:: IO , tangent :: Tangent{P} ) where P
60
60
print (io, " Tangent{" )
61
61
show (io, P)
62
62
print (io, " }" )
63
- if isempty (backing (comp ))
63
+ if isempty (backing (tangent ))
64
64
print (io, " ()" ) # so it doesn't show `NamedTuple()`
65
65
else
66
66
# allow Tuple or NamedTuple `show` to do the rendering of brackets etc
67
- show (io, backing (comp ))
67
+ show (io, backing (tangent ))
68
68
end
69
69
end
70
70
71
- function Base. getindex (comp :: Tangent{P, T} , idx:: Int ) where {P, T<: Union{Tuple, NamedTuple} }
72
- back = backing (canonicalize (comp ))
71
+ function Base. getindex (tangent :: Tangent{P, T} , idx:: Int ) where {P, T<: Union{Tuple, NamedTuple} }
72
+ back = backing (canonicalize (tangent ))
73
73
return unthunk (getfield (back, idx))
74
74
end
75
- function Base. getindex (comp :: Tangent{P, T} , idx:: Symbol ) where {P, T<: NamedTuple }
75
+ function Base. getindex (tangent :: Tangent{P, T} , idx:: Symbol ) where {P, T<: NamedTuple }
76
76
hasfield (T, idx) || return ZeroTangent ()
77
- return unthunk (getfield (backing (comp ), idx))
77
+ return unthunk (getfield (backing (tangent ), idx))
78
78
end
79
- function Base. getindex (comp :: Tangent , idx) where {P, T<: AbstractDict }
80
- return unthunk (getindex (backing (comp ), idx))
79
+ function Base. getindex (tangent :: Tangent , idx) where {P, T<: AbstractDict }
80
+ return unthunk (getindex (backing (tangent ), idx))
81
81
end
82
82
83
- function Base. getproperty (comp :: Tangent , idx:: Int )
84
- back = backing (canonicalize (comp ))
83
+ function Base. getproperty (tangent :: Tangent , idx:: Int )
84
+ back = backing (canonicalize (tangent ))
85
85
return unthunk (getfield (back, idx))
86
86
end
87
- function Base. getproperty (comp :: Tangent{P, T} , idx:: Symbol ) where {P, T<: NamedTuple }
87
+ function Base. getproperty (tangent :: Tangent{P, T} , idx:: Symbol ) where {P, T<: NamedTuple }
88
88
hasfield (T, idx) || return ZeroTangent ()
89
- return unthunk (getfield (backing (comp ), idx))
89
+ return unthunk (getfield (backing (tangent ), idx))
90
90
end
91
91
92
- Base. keys (comp :: Tangent ) = keys (backing (comp ))
93
- Base. propertynames (comp :: Tangent ) = propertynames (backing (comp ))
92
+ Base. keys (tangent :: Tangent ) = keys (backing (tangent ))
93
+ Base. propertynames (tangent :: Tangent ) = propertynames (backing (tangent ))
94
94
95
- Base. haskey (comp :: Tangent , key) = haskey (backing (comp ), key)
95
+ Base. haskey (tangent :: Tangent , key) = haskey (backing (tangent ), key)
96
96
if isdefined (Base, :hasproperty )
97
- Base. hasproperty (comp :: Tangent , key:: Symbol ) = hasproperty (backing (comp ), key)
97
+ Base. hasproperty (tangent :: Tangent , key:: Symbol ) = hasproperty (backing (tangent ), key)
98
98
end
99
99
100
- Base. iterate (comp :: Tangent , args... ) = iterate (backing (comp ), args... )
101
- Base. length (comp :: Tangent ) = length (backing (comp ))
100
+ Base. iterate (tangent :: Tangent , args... ) = iterate (backing (tangent ), args... )
101
+ Base. length (tangent :: Tangent ) = length (backing (tangent ))
102
102
Base. eltype (:: Type{<:Tangent{<:Any, T}} ) where T = eltype (T)
103
- function Base. reverse (comp :: Tangent )
104
- rev_backing = reverse (backing (comp ))
103
+ function Base. reverse (tangent :: Tangent )
104
+ rev_backing = reverse (backing (tangent ))
105
105
Tangent {typeof(rev_backing), typeof(rev_backing)} (rev_backing)
106
106
end
107
107
108
- function Base. indexed_iterate (comp :: Tangent{P,<:Tuple} , i:: Int , state= 1 ) where {P}
109
- return Base. indexed_iterate (backing (comp ), i, state)
108
+ function Base. indexed_iterate (tangent :: Tangent{P,<:Tuple} , i:: Int , state= 1 ) where {P}
109
+ return Base. indexed_iterate (backing (tangent ), i, state)
110
110
end
111
111
112
- function Base. map (f, comp :: Tangent{P, <:Tuple} ) where P
113
- vals:: Tuple = map (f, backing (comp ))
112
+ function Base. map (f, tangent :: Tangent{P, <:Tuple} ) where P
113
+ vals:: Tuple = map (f, backing (tangent ))
114
114
return Tangent {P, typeof(vals)} (vals)
115
115
end
116
- function Base. map (f, comp :: Tangent{P, <:NamedTuple{L}} ) where {P, L}
117
- vals = map (f, Tuple (backing (comp )))
116
+ function Base. map (f, tangent :: Tangent{P, <:NamedTuple{L}} ) where {P, L}
117
+ vals = map (f, Tuple (backing (tangent )))
118
118
named_vals = NamedTuple {L, typeof(vals)} (vals)
119
119
return Tangent {P, typeof(named_vals)} (named_vals)
120
120
end
121
- function Base. map (f, comp :: Tangent{P, <:Dict} ) where {P<: Dict }
122
- return Tangent {P} (Dict (k => f (v) for (k, v) in backing (comp )))
121
+ function Base. map (f, tangent :: Tangent{P, <:Dict} ) where {P<: Dict }
122
+ return Tangent {P} (Dict (k => f (v) for (k, v) in backing (tangent )))
123
123
end
124
124
125
- Base. conj (comp :: Tangent ) = map (conj, comp )
125
+ Base. conj (tangent :: Tangent ) = map (conj, tangent )
126
126
127
127
"""
128
128
backing(x)
129
129
130
130
Accesses the backing field of a `Tangent`,
131
- or destructures any other composite type into a `NamedTuple`.
132
- Identity function on `Tuple`. and `NamedTuple`s.
131
+ or destructures any other struct type into a `NamedTuple`.
132
+ Identity function on `Tuple`s and `NamedTuple`s.
133
133
134
134
This is an internal function used to simplify operations between `Tangent`s and the
135
135
primal types.
@@ -145,15 +145,15 @@ function backing(x::T)::NamedTuple where T
145
145
# so the first 4 lines of the branchs look the same, but can not be moved out.
146
146
# see https://github.com/JuliaLang/julia/issues/34283
147
147
if @generated
148
- ! isstructtype (T) && throw (DomainError (T, " backing can only be use on composite types" ))
148
+ ! isstructtype (T) && throw (DomainError (T, " backing can only be used on struct types" ))
149
149
nfields = fieldcount (T)
150
150
names = fieldnames (T)
151
151
types = fieldtypes (T)
152
152
153
153
vals = Expr (:tuple , ntuple (ii-> :(getfield (x, $ ii)), nfields)... )
154
154
return :(NamedTuple {$names, Tuple{$(types...)}} ($ vals))
155
155
else
156
- ! isstructtype (T) && throw (DomainError (T, " backing can only be use on composite types" ))
156
+ ! isstructtype (T) && throw (DomainError (T, " backing can only be used on struct types" ))
157
157
nfields = fieldcount (T)
158
158
names = fieldnames (T)
159
159
types = fieldtypes (T)
@@ -164,15 +164,15 @@ function backing(x::T)::NamedTuple where T
164
164
end
165
165
166
166
"""
167
- canonicalize(comp ::Tangent{P}) -> Tangent{P}
167
+ canonicalize(tangent ::Tangent{P}) -> Tangent{P}
168
168
169
169
Return the canonical `Tangent` for the primal type `P`.
170
170
The property names of the returned `Tangent` match the field names of the primal,
171
- and all fields of `P` not present in the input `comp ` are explictly set to `ZeroTangent()`.
171
+ and all fields of `P` not present in the input `tangent ` are explictly set to `ZeroTangent()`.
172
172
"""
173
- function canonicalize (comp :: Tangent{P, <:NamedTuple{L}} ) where {P,L}
173
+ function canonicalize (tangent :: Tangent{P, <:NamedTuple{L}} ) where {P,L}
174
174
nil = _zeroed_backing (P)
175
- combined = merge (nil, backing (comp ))
175
+ combined = merge (nil, backing (tangent ))
176
176
if length (combined) != = fieldcount (P)
177
177
throw (ArgumentError (
178
178
" Tangent fields do not match primal fields.\n " *
@@ -182,17 +182,17 @@ function canonicalize(comp::Tangent{P, <:NamedTuple{L}}) where {P,L}
182
182
return Tangent {P, typeof(combined)} (combined)
183
183
end
184
184
185
- # Tuple composites are always in their canonical form
186
- canonicalize (comp :: Tangent{<:Tuple, <:Tuple} ) = comp
185
+ # Tuple tangents are always in their canonical form
186
+ canonicalize (tangent :: Tangent{<:Tuple, <:Tuple} ) = tangent
187
187
188
- # Dict composite are always in their canonical form.
189
- canonicalize (comp :: Tangent{<:Any, <:AbstractDict} ) = comp
188
+ # Dict tangents are always in their canonical form.
189
+ canonicalize (tangent :: Tangent{<:Any, <:AbstractDict} ) = tangent
190
190
191
191
# Tangents of unspecified primal types (indicated by specifying exactly `Any`)
192
192
# all combinations of type-params are specified here to avoid ambiguities
193
- canonicalize (comp :: Tangent{Any, <:NamedTuple{L}} ) where {L} = comp
194
- canonicalize (comp :: Tangent{Any, <:Tuple} ) where {L} = comp
195
- canonicalize (comp :: Tangent{Any, <:AbstractDict} ) where {L} = comp
193
+ canonicalize (tangent :: Tangent{Any, <:NamedTuple{L}} ) where {L} = tangent
194
+ canonicalize (tangent :: Tangent{Any, <:Tuple} ) where {L} = tangent
195
+ canonicalize (tangent :: Tangent{Any, <:AbstractDict} ) where {L} = tangent
196
196
197
197
"""
198
198
_zeroed_backing(P)
@@ -213,7 +213,7 @@ Constructs an object of type `T`, with the given fields.
213
213
Fields must be correct in name and type, and `T` must have a default constructor.
214
214
215
215
This internally is called to construct structs of the primal type `T`,
216
- after an operation such as the addition of a primal to a composite.
216
+ after an operation such as the addition of a primal to a tangent
217
217
218
218
It should be overloaded, if `T` does not have a default constructor,
219
219
or if `T` needs to maintain some invarients between its fields.
0 commit comments