From dbe7765455dc9cb54ad7f1f0dee8c050f7b340ee Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 17:18:21 -0500 Subject: [PATCH 1/7] Eagerly evaluate scalar rules Master behavior ```julia julia> @scalar_rule(one(x), Zero()) julia> frule(one, 1, Zero(), [1, 2]) (1, Zero()) julia> frule(one, 1, Zero(), One()) (1, Zero()) ``` Desirable behavior ```julia julia> @scalar_rule(one(x), Zero()) julia> frule(one, 1, Zero(), [1, 2]) (1, [0, 0]) julia> frule(one, 1, Zero(), One()) (1, Thunk(var"#8#10"()) ) ``` --- src/differential_arithmetic.jl | 9 +++++---- src/rule_definition_tools.jl | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 2d4434a93..3ee1723bf 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -57,11 +57,12 @@ end Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b) Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b) for T in (:Any,) - @eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b - @eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b) + # we want to eagerly compute the result when thunk meets other types + @eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b + @eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b) - @eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b - @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) + @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b + @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) end ################## Composite ############################################################## diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d631f6131..65e41c39e 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -209,7 +209,8 @@ function propagation_expr(Δs, ∂s) # This is basically Δs ⋅ ∂s ∂s = map(esc, ∂s) - ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), length(∂s)) + # this is
necessary since we want to eagerly evaluate the result + ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] return :(+($(∂_mul_Δs...))) end From 9abe813f926470d6c19dcc100e97a90f4a49ac26 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 17:23:45 -0500 Subject: [PATCH 2/7] New release --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2ed20dd34..18bb57a35 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.5.0" +version = "0.5.1" [compat] julia = "^1.0" From cf2bc6e3328ec8fcd93379769937d776276b9fae Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sat, 11 Jan 2020 17:33:41 -0500 Subject: [PATCH 3/7] Add tests --- test/rules.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/rules.jl b/test/rules.jl index f280ebc59..1845f5739 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -8,6 +8,9 @@ cool(x, y) = x + y + 1 dummy_identity(x) = x @scalar_rule(dummy_identity(x), One()) +nice(x) = 1 +@scalar_rule(nice(x), Zero()) + ####### _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @@ -31,11 +34,16 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test cool_methods == only_methods frx, cool_pushforward = frule(cool, 1, dself, 1) - @test frx == 2 - @test cool_pushforward == 1 + @test frx === 2 + @test cool_pushforward === 1 rrx, cool_pullback = rrule(cool, 1) self, rr1 = cool_pullback(1) - @test self == NO_FIELDS - @test rrx == 2 - @test rr1 == 1 + @test self === NO_FIELDS + @test rrx === 2 + @test rr1 === 1 + + frx, nice_pushforward = frule(nice, 1, dself, 1) + @test nice_pushforward === 0 + rrx, nice_pullback = rrule(nice, 1) + @test (NO_FIELDS, 0) === nice_pullback(1) end From d329ed8fbdd527b8a319479bf6b32b23d2bcb09d Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 14:36:07 -0500 Subject: [PATCH 4/7] Revert "Eagerly
evaluate scalar rules" This reverts commit dbe7765455dc9cb54ad7f1f0dee8c050f7b340ee. --- src/differential_arithmetic.jl | 9 ++++----- src/rule_definition_tools.jl | 3 +-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 3ee1723bf..2d4434a93 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -57,12 +57,11 @@ end Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b) Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b) for T in (:Any,) - # we want to eagerly compute the result when thunk meets other types - @eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b) + @eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b + @eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b) - @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b - @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) + @eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b + @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end ################## Composite ############################################################## diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 65e41c39e..d631f6131 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -209,8 +209,7 @@ function propagation_expr(Δs, ∂s) # This is basically Δs ⋅ ∂s ∂s = map(esc, ∂s) - # this is necessary since we want to eagerly evaluate the result - ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] + ∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), length(∂s)) return :(+($(∂_mul_Δs...))) end From 0a66a95b3c56a82c660c03113039fcf3d0026639 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 14:41:38 -0500 Subject: [PATCH 5/7] Redefine * between ::Zero and ::Any --- src/differential_arithmetic.jl | 10 +++++++--- test/differentials/zero.jl | 14 +++++++------- 2
files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 2d4434a93..299616532 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -32,15 +32,19 @@ Base.:*(::DoesNotExist, ::Zero) = Zero() Base.:*(::Zero, ::DoesNotExist) = Zero() -Base.:+(::Zero, b::Zero) = Zero() +Base.:+(::Zero, ::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() for T in (:One, :AbstractThunk, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a - @eval Base.:*(::Zero, ::$T) = Zero() - @eval Base.:*(::$T, ::Zero) = Zero() + if T !== :Any + @eval Base.:*(::Zero, ::$T) = Zero() + @eval Base.:*(::$T, ::Zero) = Zero() + end end +Base.:*(::Zero, x) = zero(x) +Base.:*(x, ::Zero) = zero(x) Base.:+(a::One, b::One) = extern(a) + extern(b) diff --git a/test/differentials/zero.jl b/test/differentials/zero.jl index 54c16367c..24d939d97 100644 --- a/test/differentials/zero.jl +++ b/test/differentials/zero.jl @@ -1,15 +1,15 @@ @testset "Zero" begin z = Zero() @test extern(z) === false - @test z + z == z - @test z + 1 == 1 - @test 1 + z == 1 - @test z * z == z - @test z * 1 == z - @test 1 * z == z + @test z + z === z + @test z + 1 === 1 + @test 1 + z === 1 + @test z * z === z + @test z * 1 === 0 + @test 1 * z === 0 for x in z @test x === z end @test broadcastable(z) isa Ref{Zero} - @test conj(z) == z + @test conj(z) === z end From 6a142594490c26ffb1e6a33852c493a4e7567a41 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 15:37:48 -0500 Subject: [PATCH 6/7] Make it nicer --- src/differential_arithmetic.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 299616532..fc7680803 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -38,14 +38,10 @@ for T in (:One, :AbstractThunk, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a - if T 
!== :Any - @eval Base.:*(::Zero, ::$T) = Zero() - @eval Base.:*(::$T, ::Zero) = Zero() - end + @eval Base.:*(::Zero, x::$T) = zero(x) + @eval Base.:*(x::$T, ::Zero) = zero(x) end -Base.:*(::Zero, x) = zero(x) -Base.:*(x, ::Zero) = zero(x) - +Base.zero(::AbstractDifferential) = Zero() Base.:+(a::One, b::One) = extern(a) + extern(b) Base.:*(::One, ::One) = One() From df2eee0c9b27d35fce19bb20bb1312da95b22bea Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 12 Jan 2020 15:47:14 -0500 Subject: [PATCH 7/7] Add tests and move zero(::AbstractDifferential) to the right folder --- src/differential_arithmetic.jl | 1 - src/differentials/zero.jl | 2 ++ test/differentials/zero.jl | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index fc7680803..985f1204f 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -41,7 +41,6 @@ for T in (:One, :AbstractThunk, :Any) @eval Base.:*(::Zero, x::$T) = zero(x) @eval Base.:*(x::$T, ::Zero) = zero(x) end -Base.zero(::AbstractDifferential) = Zero() Base.:+(a::One, b::One) = extern(a) + extern(b) Base.:*(::One, ::One) = One() diff --git a/src/differentials/zero.jl b/src/differentials/zero.jl index 3e5312a72..2249a0b95 100644 --- a/src/differentials/zero.jl +++ b/src/differentials/zero.jl @@ -11,3 +11,5 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) Base.iterate(x::Zero) = (x, nothing) Base.iterate(::Zero, ::Any) = nothing + +Base.zero(::AbstractDifferential) = Zero() diff --git a/test/differentials/zero.jl b/test/differentials/zero.jl index 24d939d97..30f89a76f 100644 --- a/test/differentials/zero.jl +++ b/test/differentials/zero.jl @@ -12,4 +12,8 @@ end @test broadcastable(z) isa Ref{Zero} @test conj(z) === z + @test zero(@thunk(3)) === z + @test zero(One()) === z + @test zero(DoesNotExist()) === z + @test zero(Composite{Tuple{Int,Int}}((1, 2))) === z end