Skip to content

Commit 6767633

Browse files
committed
composite -> tangent
1 parent f2a4061 commit 6767633

File tree

3 files changed

+54
-54
lines changed

3 files changed

+54
-54
lines changed

src/tangent_arithmetic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,6 @@ Base.:+(a::Tangent{P}, b::P) where P = b + a
165165
# In general one doesn't have to represent multiplications of 2 differentials
166166
# Only of a differential and a scaling factor (generally `Real`)
167167
for T in (:Any,)
168-
@eval Base.:*(s::$T, comp::Tangent) = map(x->s*x, comp)
169-
@eval Base.:*(comp::Tangent, s::$T) = map(x->x*s, comp)
168+
@eval Base.:*(s::$T, tangent::Tangent) = map(x->s*x, tangent)
169+
@eval Base.:*(tangent::Tangent, s::$T) = map(x->x*s, tangent)
170170
end

src/tangent_types/tangent.jl

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ It should not be passed in by user.
1616
For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly
1717
to for a tuple.
1818
For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values
19-
via `comp.fieldname`.
19+
via `tangent.fieldname`.
2020
Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`.
2121
To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref)
2222
function is provided.
@@ -56,80 +56,80 @@ Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false
5656

5757
Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h)
5858

59-
function Base.show(io::IO, comp::Tangent{P}) where P
59+
function Base.show(io::IO, tangent::Tangent{P}) where P
6060
print(io, "Tangent{")
6161
show(io, P)
6262
print(io, "}")
63-
if isempty(backing(comp))
63+
if isempty(backing(tangent))
6464
print(io, "()") # so it doesn't show `NamedTuple()`
6565
else
6666
# allow Tuple or NamedTuple `show` to do the rendering of brackets etc
67-
show(io, backing(comp))
67+
show(io, backing(tangent))
6868
end
6969
end
7070

71-
function Base.getindex(comp::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}}
72-
back = backing(canonicalize(comp))
71+
function Base.getindex(tangent::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}}
72+
back = backing(canonicalize(tangent))
7373
return unthunk(getfield(back, idx))
7474
end
75-
function Base.getindex(comp::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple}
75+
function Base.getindex(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple}
7676
hasfield(T, idx) || return ZeroTangent()
77-
return unthunk(getfield(backing(comp), idx))
77+
return unthunk(getfield(backing(tangent), idx))
7878
end
79-
function Base.getindex(comp::Tangent, idx) where {P, T<:AbstractDict}
80-
return unthunk(getindex(backing(comp), idx))
79+
function Base.getindex(tangent::Tangent, idx) where {P, T<:AbstractDict}
80+
return unthunk(getindex(backing(tangent), idx))
8181
end
8282

83-
function Base.getproperty(comp::Tangent, idx::Int)
84-
back = backing(canonicalize(comp))
83+
function Base.getproperty(tangent::Tangent, idx::Int)
84+
back = backing(canonicalize(tangent))
8585
return unthunk(getfield(back, idx))
8686
end
87-
function Base.getproperty(comp::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple}
87+
function Base.getproperty(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple}
8888
hasfield(T, idx) || return ZeroTangent()
89-
return unthunk(getfield(backing(comp), idx))
89+
return unthunk(getfield(backing(tangent), idx))
9090
end
9191

92-
Base.keys(comp::Tangent) = keys(backing(comp))
93-
Base.propertynames(comp::Tangent) = propertynames(backing(comp))
92+
Base.keys(tangent::Tangent) = keys(backing(tangent))
93+
Base.propertynames(tangent::Tangent) = propertynames(backing(tangent))
9494

95-
Base.haskey(comp::Tangent, key) = haskey(backing(comp), key)
95+
Base.haskey(tangent::Tangent, key) = haskey(backing(tangent), key)
9696
if isdefined(Base, :hasproperty)
97-
Base.hasproperty(comp::Tangent, key::Symbol) = hasproperty(backing(comp), key)
97+
Base.hasproperty(tangent::Tangent, key::Symbol) = hasproperty(backing(tangent), key)
9898
end
9999

100-
Base.iterate(comp::Tangent, args...) = iterate(backing(comp), args...)
101-
Base.length(comp::Tangent) = length(backing(comp))
100+
Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...)
101+
Base.length(tangent::Tangent) = length(backing(tangent))
102102
Base.eltype(::Type{<:Tangent{<:Any, T}}) where T = eltype(T)
103-
function Base.reverse(comp::Tangent)
104-
rev_backing = reverse(backing(comp))
103+
function Base.reverse(tangent::Tangent)
104+
rev_backing = reverse(backing(tangent))
105105
Tangent{typeof(rev_backing), typeof(rev_backing)}(rev_backing)
106106
end
107107

108-
function Base.indexed_iterate(comp::Tangent{P,<:Tuple}, i::Int, state=1) where {P}
109-
return Base.indexed_iterate(backing(comp), i, state)
108+
function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P}
109+
return Base.indexed_iterate(backing(tangent), i, state)
110110
end
111111

112-
function Base.map(f, comp::Tangent{P, <:Tuple}) where P
113-
vals::Tuple = map(f, backing(comp))
112+
function Base.map(f, tangent::Tangent{P, <:Tuple}) where P
113+
vals::Tuple = map(f, backing(tangent))
114114
return Tangent{P, typeof(vals)}(vals)
115115
end
116-
function Base.map(f, comp::Tangent{P, <:NamedTuple{L}}) where{P, L}
117-
vals = map(f, Tuple(backing(comp)))
116+
function Base.map(f, tangent::Tangent{P, <:NamedTuple{L}}) where{P, L}
117+
vals = map(f, Tuple(backing(tangent)))
118118
named_vals = NamedTuple{L, typeof(vals)}(vals)
119119
return Tangent{P, typeof(named_vals)}(named_vals)
120120
end
121-
function Base.map(f, comp::Tangent{P, <:Dict}) where {P<:Dict}
122-
return Tangent{P}(Dict(k => f(v) for (k, v) in backing(comp)))
121+
function Base.map(f, tangent::Tangent{P, <:Dict}) where {P<:Dict}
122+
return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent)))
123123
end
124124

125-
Base.conj(comp::Tangent) = map(conj, comp)
125+
Base.conj(tangent::Tangent) = map(conj, tangent)
126126

127127
"""
128128
backing(x)
129129
130130
Accesses the backing field of a `Tangent`,
131-
or destructures any other composite type into a `NamedTuple`.
132-
Identity function on `Tuple`. and `NamedTuple`s.
131+
or destructures any other struct type into a `NamedTuple`.
132+
Identity function on `Tuple`s and `NamedTuple`s.
133133
134134
This is an internal function used to simplify operations between `Tangent`s and the
135135
primal types.
@@ -145,15 +145,15 @@ function backing(x::T)::NamedTuple where T
145145
# so the first 4 lines of the branchs look the same, but can not be moved out.
146146
# see https://github.com/JuliaLang/julia/issues/34283
147147
if @generated
148-
!isstructtype(T) && throw(DomainError(T, "backing can only be use on composite types"))
148+
!isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types"))
149149
nfields = fieldcount(T)
150150
names = fieldnames(T)
151151
types = fieldtypes(T)
152152

153153
vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...)
154154
return :(NamedTuple{$names, Tuple{$(types...)}}($vals))
155155
else
156-
!isstructtype(T) && throw(DomainError(T, "backing can only be use on composite types"))
156+
!isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types"))
157157
nfields = fieldcount(T)
158158
names = fieldnames(T)
159159
types = fieldtypes(T)
@@ -164,15 +164,15 @@ function backing(x::T)::NamedTuple where T
164164
end
165165

166166
"""
167-
canonicalize(comp::Tangent{P}) -> Tangent{P}
167+
canonicalize(tangent::Tangent{P}) -> Tangent{P}
168168
169169
Return the canonical `Tangent` for the primal type `P`.
170170
The property names of the returned `Tangent` match the field names of the primal,
171-
and all fields of `P` not present in the input `comp` are explictly set to `ZeroTangent()`.
171+
and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`.
172172
"""
173-
function canonicalize(comp::Tangent{P, <:NamedTuple{L}}) where {P,L}
173+
function canonicalize(tangent::Tangent{P, <:NamedTuple{L}}) where {P,L}
174174
nil = _zeroed_backing(P)
175-
combined = merge(nil, backing(comp))
175+
combined = merge(nil, backing(tangent))
176176
if length(combined) !== fieldcount(P)
177177
throw(ArgumentError(
178178
"Tangent fields do not match primal fields.\n" *
@@ -182,17 +182,17 @@ function canonicalize(comp::Tangent{P, <:NamedTuple{L}}) where {P,L}
182182
return Tangent{P, typeof(combined)}(combined)
183183
end
184184

185-
# Tuple composites are always in their canonical form
186-
canonicalize(comp::Tangent{<:Tuple, <:Tuple}) = comp
185+
# Tuple tangents are always in their canonical form
186+
canonicalize(tangent::Tangent{<:Tuple, <:Tuple}) = tangent
187187

188-
# Dict composite are always in their canonical form.
189-
canonicalize(comp::Tangent{<:Any, <:AbstractDict}) = comp
188+
# Dict tangents are always in their canonical form.
189+
canonicalize(tangent::Tangent{<:Any, <:AbstractDict}) = tangent
190190

191191
# Tangents of unspecified primal types (indicated by specifying exactly `Any`)
192192
# all combinations of type-params are specified here to avoid ambiguities
193-
canonicalize(comp::Tangent{Any, <:NamedTuple{L}}) where {L} = comp
194-
canonicalize(comp::Tangent{Any, <:Tuple}) where {L} = comp
195-
canonicalize(comp::Tangent{Any, <:AbstractDict}) where {L} = comp
193+
canonicalize(tangent::Tangent{Any, <:NamedTuple{L}}) where {L} = tangent
194+
canonicalize(tangent::Tangent{Any, <:Tuple}) where {L} = tangent
195+
canonicalize(tangent::Tangent{Any, <:AbstractDict}) where {L} = tangent
196196

197197
"""
198198
_zeroed_backing(P)
@@ -213,7 +213,7 @@ Constructs an object of type `T`, with the given fields.
213213
Fields must be correct in name and type, and `T` must have a default constructor.
214214
215215
This internally is called to construct structs of the primal type `T`,
216-
after an operation such as the addition of a primal to a composite.
216+
after an operation such as the addition of a primal to a tangent
217217
218218
It should be overloaded, if `T` does not have a default constructor,
219219
or if `T` needs to maintain some invarients between its fields.

test/tangent_types/tangent.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,15 @@ end
8686

8787
# Test indexed_iterate
8888
ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3)
89-
_unpack2tuple = function(comp)
90-
a, b = comp
89+
_unpack2tuple = function(tangent)
90+
a, b = tangent
9191
return (a, b)
9292
end
9393
@inferred _unpack2tuple(ctup)
9494
@test _unpack2tuple(ctup) === (2.0, 3)
9595

9696
# Test getproperty is inferrable
97-
_unpacknamedtuple = comp -> (comp.x, comp.y)
97+
_unpacknamedtuple = tangent -> (tangent.x, tangent.y)
9898
if VERSION v"1.2"
9999
@inferred _unpacknamedtuple(Tangent{Foo}(x=2, y=3.0))
100100
@inferred _unpacknamedtuple(Tangent{Foo}(y=3.0))
@@ -111,7 +111,7 @@ end
111111

112112
d = Dict(:x => 1, :y => 2.0)
113113
cdict = Tangent{Foo, typeof(d)}(d)
114-
@test_throws MethodError reverse(Tangent{Foo}())
114+
@test_throws MethodError reverse(Tangent{Foo}())
115115
end
116116

117117
@testset "unset properties" begin
@@ -344,7 +344,7 @@ end
344344
@testset "Internals don't allocate a ton" begin
345345
bk = (; x=1.0, y=2.0)
346346
VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32
347-
347+
348348
# weaker version of the above (which should pass on all versions)
349349
@test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48
350350
@test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48

0 commit comments

Comments
 (0)