Skip to content

Commit 9c8fcd2

Browse files
authored
Fix some method ambiguities (#589)
* Fix some method ambiguities * Fixes * Update src/tangent_arithmetic.jl
1 parent f6123ee commit 9c8fcd2

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.15.5"
3+
version = "1.15.6"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_types/notimplemented.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ Base.:/(x::NotImplemented, ::Any) = throw(NotImplementedException(x))
4343
Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x))
4444
Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
4545

46+
# Fix method ambiguity errors (#589)
47+
Base.:/(x::AbstractZero, ::NotImplemented) = x
48+
Base.:/(x::NotImplemented, ::AbstractThunk) = throw(NotImplementedException(x))
49+
Base.:/(::AbstractThunk, x::NotImplemented) = throw(NotImplementedException(x))
50+
4651
Base.zero(x::NotImplemented) = throw(NotImplementedException(x))
4752
function Base.zero(::Type{<:NotImplemented})
4853
return throw(

src/tangent_types/thunks.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,13 @@ Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)
3333
Base.:(-)(a::AbstractThunk) = -unthunk(a)
3434
Base.:(-)(a::AbstractThunk, b) = unthunk(a) - b
3535
Base.:(-)(a, b::AbstractThunk) = a - unthunk(b)
36+
Base.:(-)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) - unthunk(b)
3637
Base.:(/)(a::AbstractThunk, b) = unthunk(a) / b
3738
Base.:(/)(a, b::AbstractThunk) = a / unthunk(b)
39+
Base.:(/)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) / unthunk(b)
40+
41+
# Fix method ambiguity issue
42+
Base.:/(a::AbstractZero, ::AbstractThunk) = a
3843

3944
Base.real(a::AbstractThunk) = real(unthunk(a))
4045
Base.imag(a::AbstractThunk) = imag(unthunk(a))

test/tangent_types/notimplemented.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
@test ni + ni2 === ni
2222
@test ni2 + ni === ni2
2323

24-
# multiplication and dot product
24+
# multiplication, division, and dot product
2525
@test -ni == ni
2626
for a in (true, x, thunk)
2727
@test ni * a === ni
@@ -32,6 +32,7 @@
3232
for a in (NoTangent(), ZeroTangent())
3333
@test ni * a === a
3434
@test a * ni === a
35+
@test a / ni === a
3536
@test dot(ni, a) === a
3637
@test dot(a, ni) === a
3738
end
@@ -52,8 +53,10 @@
5253
@test_throws E a - ni
5354
end
5455
@test_throws E ni - ni2
55-
@test_throws E ni / x
56-
@test_throws E x / ni
56+
for a in (true, x, thunk)
57+
@test_throws E ni / a
58+
@test_throws E a / ni
59+
end
5760
@test_throws E ni / ni2
5861
@test_throws E zero(ni)
5962
@test_throws E zero(typeof(ni))

test/tangent_types/thunks.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,15 @@
101101
@test 1 == -@thunk(-1)
102102
@test 1 == @thunk(2) - 1
103103
@test 1 == 2 - @thunk(1)
104+
@test 1 == @thunk(2) - @thunk(1)
104105
@test 1.0 == @thunk(1) / 1.0
105106
@test 1.0 == 1.0 / @thunk(1)
107+
@test 1 == @thunk(1) / @thunk(1)
108+
109+
# check method ambiguities (#589)
110+
for a in (ZeroTangent(), NoTangent())
111+
@test a / @thunk(2) === a
112+
end
106113

107114
@test 1 == real(@thunk(1 + 1im))
108115
@test 1 == imag(@thunk(1 + 1im))

0 commit comments

Comments
 (0)