Skip to content

Commit c84bfab

Browse files
authored
Merge pull request #502 from JuliaDiff/mz/-Tangent
Implement `-Tangent`
2 parents 5025704 + 43943d8 commit c84bfab

File tree

4 files changed

+12
-3
lines changed

4 files changed

+12
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.10.2"
3+
version = "1.11.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_arithmetic.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ for T in (:AbstractThunk, :Tangent, :Any)
3939
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x
4040
@eval LinearAlgebra.dot(::$T, x::NotImplemented) = x
4141
end
42+
# unary :- is the same as multiplication by -1
43+
Base.:-(x::NotImplemented) = x
4244

4345
# subtraction throws an exception: in AD we add tangents but do not subtract them
4446
# subtraction happens eg. in gradient descent which can't be performed with `NotImplemented`
45-
Base.:-(x::NotImplemented) = throw(NotImplementedException(x))
4647
Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
4748
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
4849
@eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@@ -144,6 +145,8 @@ end
144145
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
145146
Base.:+(a::Tangent{P}, b::P) where {P} = b + a
146147

148+
Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent)
149+
147150
# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
148151
# In general one doesn't have to represent multiplications of 2 differentials
149152
# Only of a differential and a scaling factor (generally `Real`)

test/tangent_types/notimplemented.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
@test ni2 + ni === ni2
2323

2424
# multiplication and dot product
25+
@test -ni == ni
2526
for a in (true, x, thunk)
2627
@test ni * a === ni
2728
@test a * ni === ni
@@ -46,7 +47,6 @@
4647

4748
# unsupported operations
4849
E = ChainRulesCore.NotImplementedException
49-
@test_throws E -ni
5050
for a in (true, x, NoTangent(), ZeroTangent(), thunk)
5151
@test_throws E ni - a
5252
@test_throws E a - ni

test/tangent_types/tangent.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,12 @@ end
321321
@test c * t == c * 2
322322
end
323323

324+
@testset "-Tangent" begin
325+
t = Tangent{Foo}(; x=1.0, y=-2.0)
326+
@test -t == Tangent{Foo}(; x=-1.0, y=2.0)
327+
@test -1.0 * t == -t
328+
end
329+
324330
@testset "scaling" begin
325331
@test (
326332
2 * Tangent{Foo}(; y=1.5, x=2.5) ==

0 commit comments

Comments
 (0)