Skip to content

Commit de61cd3

Browse files
authored
Eagerly evaluate ::Zero * ::Any (#90)
* Eagerly evaluate scalers rules Master behavior ```julia julia> @scalar_rule(one(x), Zero()) julia> frule(one, 1, Zero(), [1, 2]) (1, Zero()) julia> frule(one, 1, Zero(), One()) (1, Zero()) ``` Desirable behavior ```julia julia> @scalar_rule(one(x), Zero()) julia> frule(one, 1, Zero(), [1, 2]) (1, [0, 0]) julia> frule(one, 1, Zero(), One()) (1, Thunk(var"#8#10"()) ) ``` * New release * Add tests * Revert "Eagerly evaluate scalers rules" This reverts commit dbe7765. * Redefine * between ::Zero and ::Any * Make it nicer * Add tests and move zero(::AbstractDifferential) to the right folder
1 parent 0976087 commit de61cd3

File tree

5 files changed

+30
-17
lines changed

5 files changed

+30
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.5.0"
3+
version = "0.5.1"
44

55
[compat]
66
julia = "^1.0"

src/differential_arithmetic.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,16 @@ Base.:*(::DoesNotExist, ::Zero) = Zero()
3232
Base.:*(::Zero, ::DoesNotExist) = Zero()
3333

3434

35-
Base.:+(::Zero, b::Zero) = Zero()
35+
Base.:+(::Zero, ::Zero) = Zero()
3636
Base.:*(::Zero, ::Zero) = Zero()
3737
for T in (:One, :AbstractThunk, :Any)
3838
@eval Base.:+(::Zero, b::$T) = b
3939
@eval Base.:+(a::$T, ::Zero) = a
4040

41-
@eval Base.:*(::Zero, ::$T) = Zero()
42-
@eval Base.:*(::$T, ::Zero) = Zero()
41+
@eval Base.:*(::Zero, x::$T) = zero(x)
42+
@eval Base.:*(x::$T, ::Zero) = zero(x)
4343
end
4444

45-
4645
Base.:+(a::One, b::One) = extern(a) + extern(b)
4746
Base.:*(::One, ::One) = One()
4847
for T in (:AbstractThunk, :Any)

src/differentials/zero.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
1111

1212
Base.iterate(x::Zero) = (x, nothing)
1313
Base.iterate(::Zero, ::Any) = nothing
14+
15+
Base.zero(::AbstractDifferential) = Zero()

test/differentials/zero.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
@testset "Zero" begin
22
z = Zero()
33
@test extern(z) === false
4-
@test z + z == z
5-
@test z + 1 == 1
6-
@test 1 + z == 1
7-
@test z * z == z
8-
@test z * 1 == z
9-
@test 1 * z == z
4+
@test z + z === z
5+
@test z + 1 === 1
6+
@test 1 + z === 1
7+
@test z * z === z
8+
@test z * 1 === 0
9+
@test 1 * z === 0
1010
for x in z
1111
@test x === z
1212
end
1313
@test broadcastable(z) isa Ref{Zero}
14-
@test conj(z) == z
14+
@test conj(z) === z
15+
@test zero(@thunk(3)) === z
16+
@test zero(One()) === z
17+
@test zero(DoesNotExist()) === z
18+
@test zero(Composite{Tuple{Int,Int}}((1, 2))) === z
1519
end

test/rules.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ cool(x, y) = x + y + 1
88
dummy_identity(x) = x
99
@scalar_rule(dummy_identity(x), One())
1010

11+
nice(x) = 1
12+
@scalar_rule(nice(x), Zero())
13+
1114
#######
1215

1316
_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@@ -31,11 +34,16 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
3134
@test cool_methods == only_methods
3235

3336
frx, cool_pushforward = frule(cool, 1, dself, 1)
34-
@test frx == 2
35-
@test cool_pushforward == 1
37+
@test frx === 2
38+
@test cool_pushforward === 1
3639
rrx, cool_pullback = rrule(cool, 1)
3740
self, rr1 = cool_pullback(1)
38-
@test self == NO_FIELDS
39-
@test rrx == 2
40-
@test rr1 == 1
41+
@test self === NO_FIELDS
42+
@test rrx === 2
43+
@test rr1 === 1
44+
45+
frx, nice_pushforward = frule(nice, 1, dself, 1)
46+
@test nice_pushforward === 0
47+
rrx, nice_pullback = rrule(nice, 1)
48+
@test (NO_FIELDS, 0) === nice_pullback(1)
4149
end

0 commit comments

Comments
 (0)