Skip to content

Simplify arithmetic of NotImplemented and treat NoTangent like ZeroTangent #477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.7.1"
version = "1.7.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
61 changes: 22 additions & 39 deletions src/tangent_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,54 +16,37 @@ Notice:

# we propagate `NotImplemented` (e.g., in `@scalar_rule`)
# this requires the following definitions (see also #337)
Base.:+(x::NotImplemented, ::ZeroTangent) = x
Base.:+(::ZeroTangent, x::NotImplemented) = x
Base.:+(x::NotImplemented, ::NotImplemented) = x
Base.:*(::NotImplemented, ::ZeroTangent) = ZeroTangent()
Base.:*(::ZeroTangent, ::NotImplemented) = ZeroTangent()
for T in (:NoTangent, :AbstractThunk, :Tangent, :Any)
Base.:*(x::NotImplemented, ::NotImplemented) = x
LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x
# `NotImplemented` always "wins" +
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
@eval Base.:+(x::NotImplemented, ::$T) = x
@eval Base.:+(::$T, x::NotImplemented) = x
end
# `NotImplemented` "loses" * and dot against NoTangent and ZeroTangent
# this can be used to ignore partial derivatives that are not implemented
for T in (:ZeroTangent, :NoTangent)
@eval Base.:*(::NotImplemented, ::$T) = $T()
@eval Base.:*(::$T, ::NotImplemented) = $T()
@eval LinearAlgebra.dot(::NotImplemented, ::$T) = $T()
@eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T()
end
# `NotImplemented` "wins" * and dot for other types
for T in (:AbstractThunk, :Tangent, :Any)
@eval Base.:*(x::NotImplemented, ::$T) = x
@eval Base.:*(::$T, x::NotImplemented) = x
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x
@eval LinearAlgebra.dot(::$T, x::NotImplemented) = x
end
Base.muladd(x::NotImplemented, y, z) = x
Base.muladd(::NotImplemented, ::ZeroTangent, z) = z
Base.muladd(x::NotImplemented, y, ::ZeroTangent) = x
Base.muladd(::NotImplemented, ::ZeroTangent, ::ZeroTangent) = ZeroTangent()
Base.muladd(x, y::NotImplemented, z) = y
Base.muladd(::ZeroTangent, ::NotImplemented, z) = z
Base.muladd(x, y::NotImplemented, ::ZeroTangent) = y
Base.muladd(::ZeroTangent, ::NotImplemented, ::ZeroTangent) = ZeroTangent()
Base.muladd(x, y, z::NotImplemented) = z
Base.muladd(::ZeroTangent, y, z::NotImplemented) = z
Base.muladd(x, ::ZeroTangent, z::NotImplemented) = z
Base.muladd(::ZeroTangent, ::ZeroTangent, z::NotImplemented) = z
Base.muladd(x::NotImplemented, ::NotImplemented, z) = x
Base.muladd(x::NotImplemented, ::NotImplemented, ::ZeroTangent) = x
Base.muladd(x::NotImplemented, y, ::NotImplemented) = x
Base.muladd(::NotImplemented, ::ZeroTangent, z::NotImplemented) = z
Base.muladd(x, y::NotImplemented, ::NotImplemented) = y
Base.muladd(::ZeroTangent, ::NotImplemented, z::NotImplemented) = z
Base.muladd(x::NotImplemented, ::NotImplemented, ::NotImplemented) = x
LinearAlgebra.dot(::NotImplemented, ::ZeroTangent) = ZeroTangent()
LinearAlgebra.dot(::ZeroTangent, ::NotImplemented) = ZeroTangent()

# other common operations throw an exception
Base.:+(x::NotImplemented) = throw(NotImplementedException(x))

# subtraction throws an exception: in AD we add tangents but do not subtract them
# subtraction happens eg. in gradient descent which can't be performed with `NotImplemented`
Base.:-(x::NotImplemented) = throw(NotImplementedException(x))
Base.:-(x::NotImplemented, ::ZeroTangent) = throw(NotImplementedException(x))
Base.:-(::ZeroTangent, x::NotImplemented) = throw(NotImplementedException(x))
Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
Base.:*(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
function LinearAlgebra.dot(x::NotImplemented, ::NotImplemented)
return throw(NotImplementedException(x))
end
for T in (:NoTangent, :AbstractThunk, :Tangent, :Any)
for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any)
@eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@eval Base.:-(::$T, x::NotImplemented) = throw(NotImplementedException(x))
@eval Base.:*(::$T, x::NotImplemented) = throw(NotImplementedException(x))
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@eval LinearAlgebra.dot(::$T, x::NotImplemented) = throw(NotImplementedException(x))
end

Base.:+(::NoTangent, ::NoTangent) = NoTangent()
Expand Down
102 changes: 39 additions & 63 deletions test/tangent_types/notimplemented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,78 +6,54 @@
ni2 = ChainRulesCore.NotImplemented(
@__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error2"
)
x = rand()
thunk = @thunk(x^2)

# supported operations (for `@scalar_rule`)
x, y, z = rand(3)
# conjugate
@test conj(ni) === ni
@test muladd(ni, y, z) === ni
@test muladd(ni, ZeroTangent(), z) == z
@test muladd(ni, y, ZeroTangent()) === ni
@test muladd(ni, ZeroTangent(), ZeroTangent()) == ZeroTangent()
@test muladd(ni, ni2, z) === ni
@test muladd(ni, ni2, ZeroTangent()) === ni
@test muladd(ni, y, ni2) === ni
@test muladd(ni, ZeroTangent(), ni2) === ni2
@test muladd(x, ni, z) === ni
@test muladd(ZeroTangent(), ni, z) == z
@test muladd(x, ni, ZeroTangent()) === ni
@test muladd(ZeroTangent(), ni, ZeroTangent()) == ZeroTangent()
@test muladd(x, ni, ni2) === ni
@test muladd(ZeroTangent(), ni, ni2) === ni2
@test muladd(x, y, ni) === ni
@test muladd(ZeroTangent(), y, ni) === ni
@test muladd(x, ZeroTangent(), ni) === ni
@test muladd(ZeroTangent(), ZeroTangent(), ni) === ni
@test ni + rand() === ni
@test ni + ZeroTangent() === ni
@test ni + NoTangent() === ni
@test ni + true === ni
@test ni + @thunk(x^2) === ni
@test rand() + ni === ni
@test ZeroTangent() + ni === ni
@test NoTangent() + ni === ni
@test true + ni === ni
@test @thunk(x^2) + ni === ni

# addition
for a in (true, x, NoTangent(), ZeroTangent(), thunk)
@test ni + a === ni
@test a + ni === ni
end
@test +ni === ni
@test ni + ni2 === ni
@test ni * rand() === ni
@test ni * ZeroTangent() == ZeroTangent()
@test ZeroTangent() * ni == ZeroTangent()
@test dot(ni, ZeroTangent()) == ZeroTangent()
@test dot(ZeroTangent(), ni) == ZeroTangent()
@test ni .* rand() === ni
@test ni2 + ni === ni2

# multiplication and dot product
for a in (true, x, thunk)
@test ni * a === ni
@test a * ni === ni
@test dot(ni, a) === ni
@test dot(a, ni) === ni
end
for a in (NoTangent(), ZeroTangent())
@test ni * a === a
@test a * ni === a
@test dot(ni, a) === a
@test dot(a, ni) === a
end
@test ni * ni2 === ni
@test ni2 * ni === ni2
@test dot(ni, ni2) === ni
@test dot(ni2, ni) === ni2

# broadcasting
@test ni .* x === ni
@test x .* ni === ni
@test broadcastable(ni) isa Ref{typeof(ni)}

# unsupported operations
E = ChainRulesCore.NotImplementedException
@test_throws E +ni
@test_throws E -ni
@test_throws E ni - rand()
@test_throws E ni - ZeroTangent()
@test_throws E ni - NoTangent()
@test_throws E ni - true
@test_throws E ni - @thunk(x^2)
@test_throws E rand() - ni
@test_throws E ZeroTangent() - ni
@test_throws E NoTangent() - ni
@test_throws E true - ni
@test_throws E @thunk(x^2) - ni
for a in (true, x, NoTangent(), ZeroTangent(), thunk)
@test_throws E ni - a
@test_throws E a - ni
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems funky that adding is fine but subtracting throws.

Like I guess it's fine -- adding tangents is what happens in AD, subtracting vis not.

Subtracting happen in gradient descent.

Ok I am convinced this is the behaviour we want.
Do we have a comment saying as such?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a comment.

end
@test_throws E ni - ni2
@test_throws E rand() * ni
@test_throws E NoTangent() * ni
@test_throws E true * ni
@test_throws E @thunk(x^2) * ni
@test_throws E ni * ni2
@test_throws E dot(ni, rand())
@test_throws E dot(ni, NoTangent())
@test_throws E dot(ni, true)
@test_throws E dot(ni, @thunk(x^2))
@test_throws E dot(rand(), ni)
@test_throws E dot(NoTangent(), ni)
@test_throws E dot(true, ni)
@test_throws E dot(@thunk(x^2), ni)
@test_throws E dot(ni, ni2)
@test_throws E ni / rand()
@test_throws E rand() / ni
@test_throws E ni / x
@test_throws E x / ni
@test_throws E ni / ni2
@test_throws E zero(ni)
@test_throws E zero(typeof(ni))
Expand Down