Skip to content

Commit d2727b1

Browse files
committed
move functionality up to StructuralTangent
1 parent 79e63f3 commit d2727b1

File tree

2 files changed

+211
-180
lines changed

2 files changed

+211
-180
lines changed

src/tangent_arithmetic.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Base.:+(x::NotImplemented, ::NotImplemented) = x
2020
Base.:*(x::NotImplemented, ::NotImplemented) = x
2121
LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x
2222
# `NotImplemented` always "wins" +
23-
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
23+
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :StructuralTangent, :Any)
2424
@eval Base.:+(x::NotImplemented, ::$T) = x
2525
@eval Base.:+(::$T, x::NotImplemented) = x
2626
end
@@ -33,7 +33,7 @@ for T in (:ZeroTangent, :NoTangent)
3333
@eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T()
3434
end
3535
# `NotImplemented` "wins" * and dot for other types
36-
for T in (:AbstractThunk, :Tangent, :Any)
36+
for T in (:AbstractThunk, :StructuralTangent, :Any)
3737
@eval Base.:*(x::NotImplemented, ::$T) = x
3838
@eval Base.:*(::$T, x::NotImplemented) = x
3939
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x
@@ -55,7 +55,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent()
5555
Base.:-(::NoTangent) = NoTangent()
5656
Base.:*(::NoTangent, ::NoTangent) = NoTangent()
5757
LinearAlgebra.dot(::NoTangent, ::NoTangent) = NoTangent()
58-
for T in (:AbstractThunk, :Tangent, :Any)
58+
for T in (:AbstractThunk, :StructuralTangent, :Any)
5959
@eval Base.:+(::NoTangent, b::$T) = b
6060
@eval Base.:+(a::$T, ::NoTangent) = a
6161
@eval Base.:-(::NoTangent, b::$T) = -b
@@ -95,7 +95,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
9595
Base.:-(::ZeroTangent) = ZeroTangent()
9696
Base.:*(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
9797
LinearAlgebra.dot(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
98-
for T in (:AbstractThunk, :Tangent, :Any)
98+
for T in (:AbstractThunk, :StructuralTangent, :Any)
9999
@eval Base.:+(::ZeroTangent, b::$T) = b
100100
@eval Base.:+(a::$T, ::ZeroTangent) = a
101101
@eval Base.:-(::ZeroTangent, b::$T) = -b
@@ -126,11 +126,11 @@ for T in (:Tangent, :Any)
126126
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
127127
end
128128

129-
function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P}
129+
function Base.:+(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P}
130130
data = elementwise_add(backing(a), backing(b))
131-
return Tangent{P,typeof(data)}(data)
131+
return StructuralTangent{P}(data)
132132
end
133-
function Base.:+(a::P, d::Tangent{P}) where {P}
133+
function Base.:+(a::P, d::StructuralTangent{P}) where {P}
134134
net_backing = elementwise_add(backing(a), backing(d))
135135
if debug_mode()
136136
try
@@ -143,14 +143,14 @@ function Base.:+(a::P, d::Tangent{P}) where {P}
143143
end
144144
end
145145
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
146-
Base.:+(a::Tangent{P}, b::P) where {P} = b + a
146+
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148-
Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent)
148+
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)
149149

150150
# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
151151
# In general one doesn't have to represent multiplications of 2 tangents
152152
# Only of a tangent and a scaling factor (generally `Real`)
153153
for T in (:Number,)
154-
@eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent)
155-
@eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent)
154+
@eval Base.:*(s::$T, tangent::StructuralTangent) = map(x -> s * x, tangent)
155+
@eval Base.:*(tangent::StructuralTangent, s::$T) = map(x -> x * s, tangent)
156156
end

0 commit comments

Comments
 (0)