diff --git a/Project.toml b/Project.toml index 2ed20dd34..18bb57a35 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.5.0" +version = "0.5.1" [compat] julia = "^1.0" diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 2d4434a93..985f1204f 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -32,17 +32,16 @@ Base.:*(::DoesNotExist, ::Zero) = Zero() Base.:*(::Zero, ::DoesNotExist) = Zero() -Base.:+(::Zero, b::Zero) = Zero() +Base.:+(::Zero, ::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() for T in (:One, :AbstractThunk, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a - @eval Base.:*(::Zero, ::$T) = Zero() - @eval Base.:*(::$T, ::Zero) = Zero() + @eval Base.:*(::Zero, x::$T) = zero(x) + @eval Base.:*(x::$T, ::Zero) = zero(x) end - Base.:+(a::One, b::One) = extern(a) + extern(b) Base.:*(::One, ::One) = One() for T in (:AbstractThunk, :Any) diff --git a/src/differentials/zero.jl b/src/differentials/zero.jl index 3e5312a72..2249a0b95 100644 --- a/src/differentials/zero.jl +++ b/src/differentials/zero.jl @@ -11,3 +11,5 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) Base.iterate(x::Zero) = (x, nothing) Base.iterate(::Zero, ::Any) = nothing + +Base.zero(::AbstractDifferential) = Zero() diff --git a/test/differentials/zero.jl b/test/differentials/zero.jl index 54c16367c..30f89a76f 100644 --- a/test/differentials/zero.jl +++ b/test/differentials/zero.jl @@ -1,15 +1,19 @@ @testset "Zero" begin z = Zero() @test extern(z) === false - @test z + z == z - @test z + 1 == 1 - @test 1 + z == 1 - @test z * z == z - @test z * 1 == z - @test 1 * z == z + @test z + z === z + @test z + 1 === 1 + @test 1 + z === 1 + @test z * z === z + @test z * 1 === 0 + @test 1 * z === 0 for x in z @test x === z end @test broadcastable(z) isa Ref{Zero} - @test conj(z) == z + @test conj(z) === z + @test zero(@thunk(3)) === z + @test zero(One()) === z + @test zero(DoesNotExist()) === z + @test zero(Composite{Tuple{Int,Int}}((1, 2))) === z end diff --git a/test/rules.jl b/test/rules.jl index f280ebc59..1845f5739 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -8,6 +8,9 @@ cool(x, y) = x + y + 1 dummy_identity(x) = x @scalar_rule(dummy_identity(x), One()) +nice(x) = 1 +@scalar_rule(nice(x), Zero()) + ####### _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @@ -31,11 +34,16 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test cool_methods == only_methods frx, cool_pushforward = frule(cool, 1, dself, 1) - @test frx == 2 - @test cool_pushforward == 1 + @test frx === 2 + @test cool_pushforward === 1 rrx, cool_pullback = rrule(cool, 1) self, rr1 = cool_pullback(1) - @test self == NO_FIELDS - @test rrx == 2 - @test rr1 == 1 + @test self === NO_FIELDS + @test rrx === 2 + @test rr1 === 1 + + frx, nice_pushforward = frule(nice, 1, dself, 1) + @test nice_pushforward === 0 + rrx, nice_pullback = rrule(nice, 1) + @test (NO_FIELDS, 0) === nice_pullback(1) end