From 249b441a728c6d7aeacd4601b20ed25de2ee09c1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 Oct 2021 01:03:57 +0200 Subject: [PATCH 1/4] Simplify arithmetic for `NotImplemented` and treat `NoTangent` like `ZeroTangent` --- src/tangent_arithmetic.jl | 55 ++++++---------- test/tangent_types/notimplemented.jl | 93 ++++++++++------------------ 2 files changed, 53 insertions(+), 95 deletions(-) diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 03c831396..91a427999 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -16,54 +16,37 @@ Notice: # we propagate `NotImplemented` (e.g., in `@scalar_rule`) # this requires the following definitions (see also #337) -Base.:+(x::NotImplemented, ::ZeroTangent) = x -Base.:+(::ZeroTangent, x::NotImplemented) = x Base.:+(x::NotImplemented, ::NotImplemented) = x -Base.:*(::NotImplemented, ::ZeroTangent) = ZeroTangent() -Base.:*(::ZeroTangent, ::NotImplemented) = ZeroTangent() -for T in (:NoTangent, :AbstractThunk, :Tangent, :Any) +Base.:*(x::NotImplemented, ::NotImplemented) = x +LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x +# `NotImplemented` always "wins" + +for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any) @eval Base.:+(x::NotImplemented, ::$T) = x @eval Base.:+(::$T, x::NotImplemented) = x +end +# `NotImplemented` "loses" * and dot against NoTangent and ZeroTangent +# this can be used to ignore partial derivatives that are not implemented +for T in (:ZeroTangent, :NoTangent) + @eval Base.:*(::NotImplemented, ::$T) = $T() + @eval Base.:*(::$T, ::NotImplemented) = $T() + @eval LinearAlgebra.dot(::NotImplemented, ::$T) = $T() + @eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T() +end +# `NotImplemented` "wins" * and dot for other types +for T in (:AbstractThunk, :Tangent, :Any) @eval Base.:*(x::NotImplemented, ::$T) = x + @eval Base.:*(::$T, x::NotImplemented) = x + @eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x + @eval 
LinearAlgebra.dot(::$T, x::NotImplemented) = x end -Base.muladd(x::NotImplemented, y, z) = x -Base.muladd(::NotImplemented, ::ZeroTangent, z) = z -Base.muladd(x::NotImplemented, y, ::ZeroTangent) = x -Base.muladd(::NotImplemented, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() -Base.muladd(x, y::NotImplemented, z) = y -Base.muladd(::ZeroTangent, ::NotImplemented, z) = z -Base.muladd(x, y::NotImplemented, ::ZeroTangent) = y -Base.muladd(::ZeroTangent, ::NotImplemented, ::ZeroTangent) = ZeroTangent() -Base.muladd(x, y, z::NotImplemented) = z -Base.muladd(::ZeroTangent, y, z::NotImplemented) = z -Base.muladd(x, ::ZeroTangent, z::NotImplemented) = z -Base.muladd(::ZeroTangent, ::ZeroTangent, z::NotImplemented) = z -Base.muladd(x::NotImplemented, ::NotImplemented, z) = x -Base.muladd(x::NotImplemented, ::NotImplemented, ::ZeroTangent) = x -Base.muladd(x::NotImplemented, y, ::NotImplemented) = x -Base.muladd(::NotImplemented, ::ZeroTangent, z::NotImplemented) = z -Base.muladd(x, y::NotImplemented, ::NotImplemented) = y -Base.muladd(::ZeroTangent, ::NotImplemented, z::NotImplemented) = z -Base.muladd(x::NotImplemented, ::NotImplemented, ::NotImplemented) = x -LinearAlgebra.dot(::NotImplemented, ::ZeroTangent) = ZeroTangent() -LinearAlgebra.dot(::ZeroTangent, ::NotImplemented) = ZeroTangent() # other common operations throw an exception Base.:+(x::NotImplemented) = throw(NotImplementedException(x)) Base.:-(x::NotImplemented) = throw(NotImplementedException(x)) -Base.:-(x::NotImplemented, ::ZeroTangent) = throw(NotImplementedException(x)) -Base.:-(::ZeroTangent, x::NotImplemented) = throw(NotImplementedException(x)) Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) -Base.:*(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) -function LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) - return throw(NotImplementedException(x)) -end -for T in (:NoTangent, :AbstractThunk, :Tangent, :Any) +for T in (:ZeroTangent, 
:NoTangent, :AbstractThunk, :Tangent, :Any) @eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x)) @eval Base.:-(::$T, x::NotImplemented) = throw(NotImplementedException(x)) - @eval Base.:*(::$T, x::NotImplemented) = throw(NotImplementedException(x)) - @eval LinearAlgebra.dot(x::NotImplemented, ::$T) = throw(NotImplementedException(x)) - @eval LinearAlgebra.dot(::$T, x::NotImplemented) = throw(NotImplementedException(x)) end Base.:+(::NoTangent, ::NoTangent) = NoTangent() diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index fe3668518..873cb234f 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -6,76 +6,51 @@ ni2 = ChainRulesCore.NotImplemented( @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" ) - - # supported operations (for `@scalar_rule`) x, y, z = rand(3) + + # conjugate @test conj(ni) === ni - @test muladd(ni, y, z) === ni - @test muladd(ni, ZeroTangent(), z) == z - @test muladd(ni, y, ZeroTangent()) === ni - @test muladd(ni, ZeroTangent(), ZeroTangent()) == ZeroTangent() - @test muladd(ni, ni2, z) === ni - @test muladd(ni, ni2, ZeroTangent()) === ni - @test muladd(ni, y, ni2) === ni - @test muladd(ni, ZeroTangent(), ni2) === ni2 - @test muladd(x, ni, z) === ni - @test muladd(ZeroTangent(), ni, z) == z - @test muladd(x, ni, ZeroTangent()) === ni - @test muladd(ZeroTangent(), ni, ZeroTangent()) == ZeroTangent() - @test muladd(x, ni, ni2) === ni - @test muladd(ZeroTangent(), ni, ni2) === ni2 - @test muladd(x, y, ni) === ni - @test muladd(ZeroTangent(), y, ni) === ni - @test muladd(x, ZeroTangent(), ni) === ni - @test muladd(ZeroTangent(), ZeroTangent(), ni) === ni - @test ni + rand() === ni - @test ni + ZeroTangent() === ni - @test ni + NoTangent() === ni - @test ni + true === ni - @test ni + @thunk(x^2) === ni - @test rand() + ni === ni - @test ZeroTangent() + ni === ni - @test NoTangent() + ni === ni - @test true + ni === ni - @test 
@thunk(x^2) + ni === ni + + # addition + for a in (rand(), NoTangent(), ZeroTangent(), true, @thunk(x^2)) + @test ni + a === ni + @test a + ni === ni + end @test ni + ni2 === ni - @test ni * rand() === ni - @test ni * ZeroTangent() == ZeroTangent() - @test ZeroTangent() * ni == ZeroTangent() - @test dot(ni, ZeroTangent()) == ZeroTangent() - @test dot(ZeroTangent(), ni) == ZeroTangent() + @test ni2 + ni === ni2 + + # multiplication and dot product + for a in (rand(), true, @thunk(x^2)) + @test ni * a === ni + @test a * ni === ni + @test dot(ni, a) === ni + @test dot(a, ni) === ni + end + for a in (NoTangent(), ZeroTangent()) + @test ni * a === a + @test a * ni === a + @test dot(ni, a) === a + @test dot(a, ni) === a + end + @test ni * ni2 === ni + @test ni2 * ni === ni2 + @test dot(ni, ni2) === ni + @test dot(ni2, ni) === ni2 + + # broadcasting @test ni .* rand() === ni + @test rand() .* ni === ni @test broadcastable(ni) isa Ref{typeof(ni)} # unsupported operations E = ChainRulesCore.NotImplementedException @test_throws E +ni @test_throws E -ni - @test_throws E ni - rand() - @test_throws E ni - ZeroTangent() - @test_throws E ni - NoTangent() - @test_throws E ni - true - @test_throws E ni - @thunk(x^2) - @test_throws E rand() - ni - @test_throws E ZeroTangent() - ni - @test_throws E NoTangent() - ni - @test_throws E true - ni - @test_throws E @thunk(x^2) - ni + for a in (rand(), NoTangent(), ZeroTangent(), true, @thunk(x^2)) + @test_throws E ni - a + @test_throws E a - ni + end @test_throws E ni - ni2 - @test_throws E rand() * ni - @test_throws E NoTangent() * ni - @test_throws E true * ni - @test_throws E @thunk(x^2) * ni - @test_throws E ni * ni2 - @test_throws E dot(ni, rand()) - @test_throws E dot(ni, NoTangent()) - @test_throws E dot(ni, true) - @test_throws E dot(ni, @thunk(x^2)) - @test_throws E dot(rand(), ni) - @test_throws E dot(NoTangent(), ni) - @test_throws E dot(true, ni) - @test_throws E dot(@thunk(x^2), ni) - @test_throws E dot(ni, ni2) @test_throws E 
ni / rand() @test_throws E rand() / ni @test_throws E ni / ni2 From 392a49ee0e6bc290af48096ee849d5ed0c869d2d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 Oct 2021 01:05:45 +0200 Subject: [PATCH 2/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a8289bc0e..112f68620 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.7.1" +version = "1.7.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 4c6f1be9ee4383d715e63be0893a5a51d92feb08 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 Oct 2021 11:01:31 +0200 Subject: [PATCH 3/4] Add comment about subtraction --- src/tangent_arithmetic.jl | 4 ++-- test/tangent_types/notimplemented.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 91a427999..9c1378aab 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -40,8 +40,8 @@ for T in (:AbstractThunk, :Tangent, :Any) @eval LinearAlgebra.dot(::$T, x::NotImplemented) = x end -# other common operations throw an exception -Base.:+(x::NotImplemented) = throw(NotImplementedException(x)) +# subtraction throws an exception: in AD we add tangents but do not subtract them +# subtraction happens e.g. 
in gradient descent which can't be performed with `NotImplemented` Base.:-(x::NotImplemented) = throw(NotImplementedException(x)) Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any) diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 873cb234f..d8361023f 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -16,6 +16,7 @@ @test ni + a === ni @test a + ni === ni end + @test +ni === ni @test ni + ni2 === ni @test ni2 + ni === ni2 @@ -44,7 +45,6 @@ # unsupported operations E = ChainRulesCore.NotImplementedException - @test_throws E +ni @test_throws E -ni for a in (rand(), NoTangent(), ZeroTangent(), true, @thunk(x^2)) @test_throws E ni - a From ea5c8899eafd4d428e438c12c8e44582f94819ab Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 Oct 2021 11:17:38 +0200 Subject: [PATCH 4/4] Simplify tests --- test/tangent_types/notimplemented.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index d8361023f..2fd337979 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -6,13 +6,14 @@ ni2 = ChainRulesCore.NotImplemented( @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" ) - x, y, z = rand(3) + x = rand() + thunk = @thunk(x^2) # conjugate @test conj(ni) === ni # addition - for a in (rand(), NoTangent(), ZeroTangent(), true, @thunk(x^2)) + for a in (true, x, NoTangent(), ZeroTangent(), thunk) @test ni + a === ni @test a + ni === ni end @@ -21,7 +22,7 @@ @test ni2 + ni === ni2 # multiplication and dot product - for a in (rand(), true, @thunk(x^2)) + for a in (true, x, thunk) @test ni * a === ni @test a * ni === ni @test dot(ni, a) === ni @@ -39,20 +40,20 @@ @test dot(ni2, ni) === ni2 # broadcasting - @test ni .* rand() === ni - @test 
rand() .* ni === ni + @test ni .* x === ni + @test x .* ni === ni @test broadcastable(ni) isa Ref{typeof(ni)} # unsupported operations E = ChainRulesCore.NotImplementedException @test_throws E -ni - for a in (rand(), NoTangent(), ZeroTangent(), true, @thunk(x^2)) + for a in (true, x, NoTangent(), ZeroTangent(), thunk) @test_throws E ni - a @test_throws E a - ni end @test_throws E ni - ni2 - @test_throws E ni / rand() - @test_throws E rand() / ni + @test_throws E ni / x + @test_throws E x / ni @test_throws E ni / ni2 @test_throws E zero(ni) @test_throws E zero(typeof(ni))