@@ -20,7 +20,7 @@ Base.:+(x::NotImplemented, ::NotImplemented) = x
20
20
Base.:* (x:: NotImplemented , :: NotImplemented ) = x
21
21
LinearAlgebra. dot (x:: NotImplemented , :: NotImplemented ) = x
22
22
# `NotImplemented` always "wins" +
23
- for T in (:ZeroTangent , :NoTangent , :AbstractThunk , :Tangent , :Any )
23
+ for T in (:ZeroTangent , :NoTangent , :AbstractThunk , :StructuralTangent , :Any )
24
24
@eval Base.:+ (x:: NotImplemented , :: $T ) = x
25
25
@eval Base.:+ (:: $T , x:: NotImplemented ) = x
26
26
end
@@ -33,7 +33,7 @@ for T in (:ZeroTangent, :NoTangent)
33
33
@eval LinearAlgebra. dot (:: $T , :: NotImplemented ) = $ T ()
34
34
end
35
35
# `NotImplemented` "wins" * and dot for other types
36
- for T in (:AbstractThunk , :Tangent , :Any )
36
+ for T in (:AbstractThunk , :StructuralTangent , :Any )
37
37
@eval Base.:* (x:: NotImplemented , :: $T ) = x
38
38
@eval Base.:* (:: $T , x:: NotImplemented ) = x
39
39
@eval LinearAlgebra. dot (x:: NotImplemented , :: $T ) = x
@@ -55,7 +55,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent()
55
55
Base.:- (:: NoTangent ) = NoTangent ()
56
56
Base.:* (:: NoTangent , :: NoTangent ) = NoTangent ()
57
57
LinearAlgebra. dot (:: NoTangent , :: NoTangent ) = NoTangent ()
58
- for T in (:AbstractThunk , :Tangent , :Any )
58
+ for T in (:AbstractThunk , :StructuralTangent , :Any )
59
59
@eval Base.:+ (:: NoTangent , b:: $T ) = b
60
60
@eval Base.:+ (a:: $T , :: NoTangent ) = a
61
61
@eval Base.:- (:: NoTangent , b:: $T ) = - b
@@ -95,7 +95,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
95
95
Base.:- (:: ZeroTangent ) = ZeroTangent ()
96
96
Base.:* (:: ZeroTangent , :: ZeroTangent ) = ZeroTangent ()
97
97
LinearAlgebra. dot (:: ZeroTangent , :: ZeroTangent ) = ZeroTangent ()
98
- for T in (:AbstractThunk , :Tangent , :Any )
98
+ for T in (:AbstractThunk , :StructuralTangent , :Any )
99
99
@eval Base.:+ (:: ZeroTangent , b:: $T ) = b
100
100
@eval Base.:+ (a:: $T , :: ZeroTangent ) = a
101
101
@eval Base.:- (:: ZeroTangent , b:: $T ) = - b
@@ -126,11 +126,11 @@ for T in (:Tangent, :Any)
126
126
@eval Base.:* (a:: $T , b:: AbstractThunk ) = a * unthunk (b)
127
127
end
128
128
129
- function Base.:+ (a:: Tangent {P} , b:: Tangent {P} ) where {P}
129
+ function Base.:+ (a:: StructuralTangent {P} , b:: StructuralTangent {P} ) where {P}
130
130
data = elementwise_add (backing (a), backing (b))
131
- return Tangent {P,typeof(data) } (data)
131
+ return StructuralTangent {P } (data)
132
132
end
133
- function Base.:+ (a:: P , d:: Tangent {P} ) where {P}
133
+ function Base.:+ (a:: P , d:: StructuralTangent {P} ) where {P}
134
134
net_backing = elementwise_add (backing (a), backing (d))
135
135
if debug_mode ()
136
136
try
@@ -143,14 +143,14 @@ function Base.:+(a::P, d::Tangent{P}) where {P}
143
143
end
144
144
end
145
145
Base.:+ (a:: Dict , d:: Tangent{P} ) where {P} = merge (+ , a, backing (d))
146
- Base.:+ (a:: Tangent {P} , b:: P ) where {P} = b + a
146
+ Base.:+ (a:: StructuralTangent {P} , b:: P ) where {P} = b + a
147
147
148
- Base.:- (tangent:: Tangent {P} ) where {P} = map (- , tangent)
148
+ Base.:- (tangent:: StructuralTangent {P} ) where {P} = map (- , tangent)
149
149
150
150
# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
151
151
# In general one doesn't have to represent multiplications of 2 tangents
152
152
# Only of a tangent and a scaling factor (generally `Real`)
153
153
for T in (:Number ,)
154
- @eval Base.:* (s:: $T , tangent:: Tangent ) = map (x -> s * x, tangent)
155
- @eval Base.:* (tangent:: Tangent , s:: $T ) = map (x -> x * s, tangent)
154
+ @eval Base.:* (s:: $T , tangent:: StructuralTangent ) = map (x -> s * x, tangent)
155
+ @eval Base.:* (tangent:: StructuralTangent , s:: $T ) = map (x -> x * s, tangent)
156
156
end
0 commit comments