Skip to content

Commit 834901e

Browse files
authored
Simplify arithmetic of NotImplemented and treat NoTangent like ZeroTangent (#477)
* Simplify arithmetic for `NotImplemented` and treat `NoTangent` like `ZeroTangent` * Bump version * Add comment about subtraction * Simplify tests
1 parent 187f9fd commit 834901e

File tree

3 files changed

+62
-103
lines changed

3 files changed

+62
-103
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.7.1"
3+
version = "1.7.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_arithmetic.jl

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,54 +16,37 @@ Notice:
1616

1717
# we propagate `NotImplemented` (e.g., in `@scalar_rule`)
1818
# this requires the following definitions (see also #337)
19-
Base.:+(x::NotImplemented, ::ZeroTangent) = x
20-
Base.:+(::ZeroTangent, x::NotImplemented) = x
2119
Base.:+(x::NotImplemented, ::NotImplemented) = x
22-
Base.:*(::NotImplemented, ::ZeroTangent) = ZeroTangent()
23-
Base.:*(::ZeroTangent, ::NotImplemented) = ZeroTangent()
24-
for T in (:NoTangent, :AbstractThunk, :Tangent, :Any)
20+
Base.:*(x::NotImplemented, ::NotImplemented) = x
21+
LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x
22+
# `NotImplemented` always "wins" +
23+
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
2524
@eval Base.:+(x::NotImplemented, ::$T) = x
2625
@eval Base.:+(::$T, x::NotImplemented) = x
26+
end
27+
# `NotImplemented` "loses" * and dot against NoTangent and ZeroTangent
28+
# this can be used to ignore partial derivatives that are not implemented
29+
for T in (:ZeroTangent, :NoTangent)
30+
@eval Base.:*(::NotImplemented, ::$T) = $T()
31+
@eval Base.:*(::$T, ::NotImplemented) = $T()
32+
@eval LinearAlgebra.dot(::NotImplemented, ::$T) = $T()
33+
@eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T()
34+
end
35+
# `NotImplemented` "wins" * and dot for other types
36+
for T in (:AbstractThunk, :Tangent, :Any)
2737
@eval Base.:*(x::NotImplemented, ::$T) = x
38+
@eval Base.:*(::$T, x::NotImplemented) = x
39+
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x
40+
@eval LinearAlgebra.dot(::$T, x::NotImplemented) = x
2841
end
29-
Base.muladd(x::NotImplemented, y, z) = x
30-
Base.muladd(::NotImplemented, ::ZeroTangent, z) = z
31-
Base.muladd(x::NotImplemented, y, ::ZeroTangent) = x
32-
Base.muladd(::NotImplemented, ::ZeroTangent, ::ZeroTangent) = ZeroTangent()
33-
Base.muladd(x, y::NotImplemented, z) = y
34-
Base.muladd(::ZeroTangent, ::NotImplemented, z) = z
35-
Base.muladd(x, y::NotImplemented, ::ZeroTangent) = y
36-
Base.muladd(::ZeroTangent, ::NotImplemented, ::ZeroTangent) = ZeroTangent()
37-
Base.muladd(x, y, z::NotImplemented) = z
38-
Base.muladd(::ZeroTangent, y, z::NotImplemented) = z
39-
Base.muladd(x, ::ZeroTangent, z::NotImplemented) = z
40-
Base.muladd(::ZeroTangent, ::ZeroTangent, z::NotImplemented) = z
41-
Base.muladd(x::NotImplemented, ::NotImplemented, z) = x
42-
Base.muladd(x::NotImplemented, ::NotImplemented, ::ZeroTangent) = x
43-
Base.muladd(x::NotImplemented, y, ::NotImplemented) = x
44-
Base.muladd(::NotImplemented, ::ZeroTangent, z::NotImplemented) = z
45-
Base.muladd(x, y::NotImplemented, ::NotImplemented) = y
46-
Base.muladd(::ZeroTangent, ::NotImplemented, z::NotImplemented) = z
47-
Base.muladd(x::NotImplemented, ::NotImplemented, ::NotImplemented) = x
48-
LinearAlgebra.dot(::NotImplemented, ::ZeroTangent) = ZeroTangent()
49-
LinearAlgebra.dot(::ZeroTangent, ::NotImplemented) = ZeroTangent()
50-
51-
# other common operations throw an exception
52-
Base.:+(x::NotImplemented) = throw(NotImplementedException(x))
42+
43+
# subtraction throws an exception: in AD we add tangents but do not subtract them
44+
# subtraction happens eg. in gradient descent which can't be performed with `NotImplemented`
5345
Base.:-(x::NotImplemented) = throw(NotImplementedException(x))
54-
Base.:-(x::NotImplemented, ::ZeroTangent) = throw(NotImplementedException(x))
55-
Base.:-(::ZeroTangent, x::NotImplemented) = throw(NotImplementedException(x))
5646
Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
57-
Base.:*(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
58-
function LinearAlgebra.dot(x::NotImplemented, ::NotImplemented)
59-
return throw(NotImplementedException(x))
60-
end
61-
for T in (:NoTangent, :AbstractThunk, :Tangent, :Any)
47+
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
6248
@eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
6349
@eval Base.:-(::$T, x::NotImplemented) = throw(NotImplementedException(x))
64-
@eval Base.:*(::$T, x::NotImplemented) = throw(NotImplementedException(x))
65-
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
66-
@eval LinearAlgebra.dot(::$T, x::NotImplemented) = throw(NotImplementedException(x))
6750
end
6851

6952
Base.:+(::NoTangent, ::NoTangent) = NoTangent()

test/tangent_types/notimplemented.jl

Lines changed: 39 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,78 +6,54 @@
66
ni2 = ChainRulesCore.NotImplemented(
77
@__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2"
88
)
9+
x = rand()
10+
thunk = @thunk(x^2)
911

10-
# supported operations (for `@scalar_rule`)
11-
x, y, z = rand(3)
12+
# conjugate
1213
@test conj(ni) === ni
13-
@test muladd(ni, y, z) === ni
14-
@test muladd(ni, ZeroTangent(), z) == z
15-
@test muladd(ni, y, ZeroTangent()) === ni
16-
@test muladd(ni, ZeroTangent(), ZeroTangent()) == ZeroTangent()
17-
@test muladd(ni, ni2, z) === ni
18-
@test muladd(ni, ni2, ZeroTangent()) === ni
19-
@test muladd(ni, y, ni2) === ni
20-
@test muladd(ni, ZeroTangent(), ni2) === ni2
21-
@test muladd(x, ni, z) === ni
22-
@test muladd(ZeroTangent(), ni, z) == z
23-
@test muladd(x, ni, ZeroTangent()) === ni
24-
@test muladd(ZeroTangent(), ni, ZeroTangent()) == ZeroTangent()
25-
@test muladd(x, ni, ni2) === ni
26-
@test muladd(ZeroTangent(), ni, ni2) === ni2
27-
@test muladd(x, y, ni) === ni
28-
@test muladd(ZeroTangent(), y, ni) === ni
29-
@test muladd(x, ZeroTangent(), ni) === ni
30-
@test muladd(ZeroTangent(), ZeroTangent(), ni) === ni
31-
@test ni + rand() === ni
32-
@test ni + ZeroTangent() === ni
33-
@test ni + NoTangent() === ni
34-
@test ni + true === ni
35-
@test ni + @thunk(x^2) === ni
36-
@test rand() + ni === ni
37-
@test ZeroTangent() + ni === ni
38-
@test NoTangent() + ni === ni
39-
@test true + ni === ni
40-
@test @thunk(x^2) + ni === ni
14+
15+
# addition
16+
for a in (true, x, NoTangent(), ZeroTangent(), thunk)
17+
@test ni + a === ni
18+
@test a + ni === ni
19+
end
20+
@test +ni === ni
4121
@test ni + ni2 === ni
42-
@test ni * rand() === ni
43-
@test ni * ZeroTangent() == ZeroTangent()
44-
@test ZeroTangent() * ni == ZeroTangent()
45-
@test dot(ni, ZeroTangent()) == ZeroTangent()
46-
@test dot(ZeroTangent(), ni) == ZeroTangent()
47-
@test ni .* rand() === ni
22+
@test ni2 + ni === ni2
23+
24+
# multiplication and dot product
25+
for a in (true, x, thunk)
26+
@test ni * a === ni
27+
@test a * ni === ni
28+
@test dot(ni, a) === ni
29+
@test dot(a, ni) === ni
30+
end
31+
for a in (NoTangent(), ZeroTangent())
32+
@test ni * a === a
33+
@test a * ni === a
34+
@test dot(ni, a) === a
35+
@test dot(a, ni) === a
36+
end
37+
@test ni * ni2 === ni
38+
@test ni2 * ni === ni2
39+
@test dot(ni, ni2) === ni
40+
@test dot(ni2, ni) === ni2
41+
42+
# broadcasting
43+
@test ni .* x === ni
44+
@test x .* ni === ni
4845
@test broadcastable(ni) isa Ref{typeof(ni)}
4946

5047
# unsupported operations
5148
E = ChainRulesCore.NotImplementedException
52-
@test_throws E +ni
5349
@test_throws E -ni
54-
@test_throws E ni - rand()
55-
@test_throws E ni - ZeroTangent()
56-
@test_throws E ni - NoTangent()
57-
@test_throws E ni - true
58-
@test_throws E ni - @thunk(x^2)
59-
@test_throws E rand() - ni
60-
@test_throws E ZeroTangent() - ni
61-
@test_throws E NoTangent() - ni
62-
@test_throws E true - ni
63-
@test_throws E @thunk(x^2) - ni
50+
for a in (true, x, NoTangent(), ZeroTangent(), thunk)
51+
@test_throws E ni - a
52+
@test_throws E a - ni
53+
end
6454
@test_throws E ni - ni2
65-
@test_throws E rand() * ni
66-
@test_throws E NoTangent() * ni
67-
@test_throws E true * ni
68-
@test_throws E @thunk(x^2) * ni
69-
@test_throws E ni * ni2
70-
@test_throws E dot(ni, rand())
71-
@test_throws E dot(ni, NoTangent())
72-
@test_throws E dot(ni, true)
73-
@test_throws E dot(ni, @thunk(x^2))
74-
@test_throws E dot(rand(), ni)
75-
@test_throws E dot(NoTangent(), ni)
76-
@test_throws E dot(true, ni)
77-
@test_throws E dot(@thunk(x^2), ni)
78-
@test_throws E dot(ni, ni2)
79-
@test_throws E ni / rand()
80-
@test_throws E rand() / ni
55+
@test_throws E ni / x
56+
@test_throws E x / ni
8157
@test_throws E ni / ni2
8258
@test_throws E zero(ni)
8359
@test_throws E zero(typeof(ni))

0 commit comments

Comments
 (0)