diff --git a/Project.toml b/Project.toml index a8289bc0e..112f68620 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.7.1" +version = "1.7.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 03c831396..9c1378aab 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -16,54 +16,37 @@ Notice: # we propagate `NotImplemented` (e.g., in `@scalar_rule`) # this requires the following definitions (see also #337) -Base.:+(x::NotImplemented, ::ZeroTangent) = x -Base.:+(::ZeroTangent, x::NotImplemented) = x Base.:+(x::NotImplemented, ::NotImplemented) = x -Base.:*(::NotImplemented, ::ZeroTangent) = ZeroTangent() -Base.:*(::ZeroTangent, ::NotImplemented) = ZeroTangent() -for T in (:NoTangent, :AbstractThunk, :Tangent, :Any) +Base.:*(x::NotImplemented, ::NotImplemented) = x +LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x +# `NotImplemented` always "wins" + +for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any) @eval Base.:+(x::NotImplemented, ::$T) = x @eval Base.:+(::$T, x::NotImplemented) = x +end +# `NotImplemented` "loses" * and dot against NoTangent and ZeroTangent +# this can be used to ignore partial derivatives that are not implemented +for T in (:ZeroTangent, :NoTangent) + @eval Base.:*(::NotImplemented, ::$T) = $T() + @eval Base.:*(::$T, ::NotImplemented) = $T() + @eval LinearAlgebra.dot(::NotImplemented, ::$T) = $T() + @eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T() +end +# `NotImplemented` "wins" * and dot for other types +for T in (:AbstractThunk, :Tangent, :Any) @eval Base.:*(x::NotImplemented, ::$T) = x + @eval Base.:*(::$T, x::NotImplemented) = x + @eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x + @eval LinearAlgebra.dot(::$T, x::NotImplemented) = x end -Base.muladd(x::NotImplemented, y, z) = x -Base.muladd(::NotImplemented, ::ZeroTangent, z) = z -Base.muladd(x::NotImplemented, y, ::ZeroTangent) = x -Base.muladd(::NotImplemented, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() -Base.muladd(x, y::NotImplemented, z) = y -Base.muladd(::ZeroTangent, ::NotImplemented, z) = z -Base.muladd(x, y::NotImplemented, ::ZeroTangent) = y -Base.muladd(::ZeroTangent, ::NotImplemented, ::ZeroTangent) = ZeroTangent() -Base.muladd(x, y, z::NotImplemented) = z -Base.muladd(::ZeroTangent, y, z::NotImplemented) = z -Base.muladd(x, ::ZeroTangent, z::NotImplemented) = z -Base.muladd(::ZeroTangent, ::ZeroTangent, z::NotImplemented) = z -Base.muladd(x::NotImplemented, ::NotImplemented, z) = x -Base.muladd(x::NotImplemented, ::NotImplemented, ::ZeroTangent) = x -Base.muladd(x::NotImplemented, y, ::NotImplemented) = x -Base.muladd(::NotImplemented, ::ZeroTangent, z::NotImplemented) = z -Base.muladd(x, y::NotImplemented, ::NotImplemented) = y -Base.muladd(::ZeroTangent, ::NotImplemented, z::NotImplemented) = z -Base.muladd(x::NotImplemented, ::NotImplemented, ::NotImplemented) = x -LinearAlgebra.dot(::NotImplemented, ::ZeroTangent) = ZeroTangent() -LinearAlgebra.dot(::ZeroTangent, ::NotImplemented) = ZeroTangent() - -# other common operations throw an exception -Base.:+(x::NotImplemented) = throw(NotImplementedException(x)) + +# subtraction throws an exception: in AD we add tangents but do not subtract them +# subtraction happens eg. in gradient descent which can't be performed with `NotImplemented` Base.:-(x::NotImplemented) = throw(NotImplementedException(x)) -Base.:-(x::NotImplemented, ::ZeroTangent) = throw(NotImplementedException(x)) -Base.:-(::ZeroTangent, x::NotImplemented) = throw(NotImplementedException(x)) Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) -Base.:*(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) -function LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) - return throw(NotImplementedException(x)) -end -for T in (:NoTangent, :AbstractThunk, :Tangent, :Any) +for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any) @eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x)) @eval Base.:-(::$T, x::NotImplemented) = throw(NotImplementedException(x)) - @eval Base.:*(::$T, x::NotImplemented) = throw(NotImplementedException(x)) - @eval LinearAlgebra.dot(x::NotImplemented, ::$T) = throw(NotImplementedException(x)) - @eval LinearAlgebra.dot(::$T, x::NotImplemented) = throw(NotImplementedException(x)) end Base.:+(::NoTangent, ::NoTangent) = NoTangent() diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index fe3668518..2fd337979 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -6,78 +6,54 @@ ni2 = ChainRulesCore.NotImplemented( @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2" ) + x = rand() + thunk = @thunk(x^2) - # supported operations (for `@scalar_rule`) - x, y, z = rand(3) + # conjugate @test conj(ni) === ni - @test muladd(ni, y, z) === ni - @test muladd(ni, ZeroTangent(), z) == z - @test muladd(ni, y, ZeroTangent()) === ni - @test muladd(ni, ZeroTangent(), ZeroTangent()) == ZeroTangent() - @test muladd(ni, ni2, z) === ni - @test muladd(ni, ni2, ZeroTangent()) === ni - @test muladd(ni, y, ni2) === ni - @test muladd(ni, ZeroTangent(), ni2) === ni2 - @test muladd(x, ni, z) === ni - @test muladd(ZeroTangent(), ni, z) == z - @test muladd(x, ni, ZeroTangent()) === ni - @test muladd(ZeroTangent(), ni, ZeroTangent()) == ZeroTangent() - @test muladd(x, ni, ni2) === ni - @test muladd(ZeroTangent(), ni, ni2) === ni2 - @test muladd(x, y, ni) === ni - @test muladd(ZeroTangent(), y, ni) === ni - @test muladd(x, ZeroTangent(), ni) === ni - @test muladd(ZeroTangent(), ZeroTangent(), ni) === ni - @test ni + rand() === ni - @test ni + ZeroTangent() === ni - @test ni + NoTangent() === ni - @test ni + true === ni - @test ni + @thunk(x^2) === ni - @test rand() + ni === ni - @test ZeroTangent() + ni === ni - @test NoTangent() + ni === ni - @test true + ni === ni - @test @thunk(x^2) + ni === ni + + # addition + for a in (true, x, NoTangent(), ZeroTangent(), thunk) + @test ni + a === ni + @test a + ni === ni + end + @test +ni === ni @test ni + ni2 === ni - @test ni * rand() === ni - @test ni * ZeroTangent() == ZeroTangent() - @test ZeroTangent() * ni == ZeroTangent() - @test dot(ni, ZeroTangent()) == ZeroTangent() - @test dot(ZeroTangent(), ni) == ZeroTangent() - @test ni .* rand() === ni + @test ni2 + ni === ni2 + + # multiplication and dot product + for a in (true, x, thunk) + @test ni * a === ni + @test a * ni === ni + @test dot(ni, a) === ni + @test dot(a, ni) === ni + end + for a in (NoTangent(), ZeroTangent()) + @test ni * a === a + @test a * ni === a + @test dot(ni, a) === a + @test dot(a, ni) === a + end + @test ni * ni2 === ni + @test ni2 * ni === ni2 + @test dot(ni, ni2) === ni + @test dot(ni2, ni) === ni2 + + # broadcasting + @test ni .* x === ni + @test x .* ni === ni @test broadcastable(ni) isa Ref{typeof(ni)} # unsupported operations E = ChainRulesCore.NotImplementedException - @test_throws E +ni @test_throws E -ni - @test_throws E ni - rand() - @test_throws E ni - ZeroTangent() - @test_throws E ni - NoTangent() - @test_throws E ni - true - @test_throws E ni - @thunk(x^2) - @test_throws E rand() - ni - @test_throws E ZeroTangent() - ni - @test_throws E NoTangent() - ni - @test_throws E true - ni - @test_throws E @thunk(x^2) - ni + for a in (true, x, NoTangent(), ZeroTangent(), thunk) + @test_throws E ni - a + @test_throws E a - ni + end @test_throws E ni - ni2 - @test_throws E rand() * ni - @test_throws E NoTangent() * ni - @test_throws E true * ni - @test_throws E @thunk(x^2) * ni - @test_throws E ni * ni2 - @test_throws E dot(ni, rand()) - @test_throws E dot(ni, NoTangent()) - @test_throws E dot(ni, true) - @test_throws E dot(ni, @thunk(x^2)) - @test_throws E dot(rand(), ni) - @test_throws E dot(NoTangent(), ni) - @test_throws E dot(true, ni) - @test_throws E dot(@thunk(x^2), ni) - @test_throws E dot(ni, ni2) - @test_throws E ni / rand() - @test_throws E rand() / ni + @test_throws E ni / x + @test_throws E x / ni @test_throws E ni / ni2 @test_throws E zero(ni) @test_throws E zero(typeof(ni))