Skip to content

Commit 0637def

Browse files
Make DoesNotExist more clearly like Zero (#86)
* Make DoesNotExist more clearly like Zero * Update src/differentials/does_not_exist.jl Co-Authored-By: willtebbutt <wt0881@my.bristol.ac.uk> * Fix use of Zero vs Zero() in tests Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
1 parent b9abd85 commit 0637def

File tree

4 files changed

+66
-16
lines changed

4 files changed

+66
-16
lines changed

src/differential_arithmetic.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,12 @@ Thus we can avoid any ambiguities.
88
99
Notice:
1010
The precedence goes:
11-
`Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
12-
Thus each of the @eval loops creating definitions of + and *
11+
`DoesNotExist, Zero, One, AbstractThunk, Composite, Any`
12+
Thus each of the @eval loops create most definitions of + and *
1313
defines the combination this type with all types of lower precidence.
1414
This means each eval loops is 1 item smaller than the previous.
1515
==#
1616

17-
Base.:+(::Zero, b::Zero) = Zero()
18-
Base.:*(::Zero, ::Zero) = Zero()
19-
for T in (:DoesNotExist, :One, :AbstractThunk, :Any)
20-
@eval Base.:+(::Zero, b::$T) = b
21-
@eval Base.:+(a::$T, ::Zero) = a
22-
23-
@eval Base.:*(::Zero, ::$T) = Zero()
24-
@eval Base.:*(::$T, ::Zero) = Zero()
25-
end
26-
27-
2817
Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
2918
Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
3019
for T in (:One, :AbstractThunk, :Any)
@@ -34,6 +23,24 @@ for T in (:One, :AbstractThunk, :Any)
3423
@eval Base.:*(::DoesNotExist, ::$T) = DoesNotExist()
3524
@eval Base.:*(::$T, ::DoesNotExist) = DoesNotExist()
3625
end
26+
# `DoesNotExist` and `Zero` have special relationship,
27+
# DoesNotExist wins add, Zero wins *. This is (in theory) to allow `*` to be used for
28+
# selecting things.
29+
Base.:+(::DoesNotExist, ::Zero) = DoesNotExist()
30+
Base.:+(::Zero, ::DoesNotExist) = DoesNotExist()
31+
Base.:*(::DoesNotExist, ::Zero) = Zero()
32+
Base.:*(::Zero, ::DoesNotExist) = Zero()
33+
34+
35+
Base.:+(::Zero, b::Zero) = Zero()
36+
Base.:*(::Zero, ::Zero) = Zero()
37+
for T in (:One, :AbstractThunk, :Any)
38+
@eval Base.:+(::Zero, b::$T) = b
39+
@eval Base.:+(a::$T, ::Zero) = a
40+
41+
@eval Base.:*(::Zero, ::$T) = Zero()
42+
@eval Base.:*(::$T, ::Zero) = Zero()
43+
end
3744

3845

3946
Base.:+(a::One, b::One) = extern(a) + extern(b)

src/differentials/does_not_exist.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,29 @@
22
DoesNotExist()
33
44
This differential indicates that the derivative Does Not Exist (D.N.E).
5-
This is not the cast that it is not implemented, but rather that it mathematically
6-
is not defined.
5+
It is the differential for a Primal type that is not differentiable.
6+
Such an Integer, or Boolean (when not being used as a represention of a valid that normally
7+
would be a floating point.)
8+
The only valid way to pertube such a values is to not change it at all.
9+
As such, `DoesNotExist` is functionally identical to `Zero()`,
10+
but provides additional semantic information.
11+
12+
If you are adding this differential to a primal then something is wrong.
13+
A optimization package making use of this might like to check for such a case.
14+
15+
!!! note:
16+
This does not indicate that the derivative it is not implemented,
17+
but rather that mathematically it is not defined.
18+
19+
This mostly shows up as the deriviative with respect to dimension, index, or size
20+
arguments.
21+
```
22+
function rrule(fill, x, len::Int)
23+
y = fill(x, len)
24+
fill_pullback(ȳ) = (NO_FIELDS, @thunk(sum(Ȳ)), DoesNotExist())
25+
return y, fill_pullback
26+
end
27+
```
728
"""
829
struct DoesNotExist <: AbstractDifferential end
930

@@ -15,4 +36,3 @@ Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist())
1536

1637
Base.iterate(x::DoesNotExist) = (x, nothing)
1738
Base.iterate(::DoesNotExist, ::Any) = nothing
18-

test/differentials/does_not_exist.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@testset "DoesNotExist" begin
2+
dne = DoesNotExist()
3+
@test_throws Exception extern(dne)
4+
@test dne + dne == dne
5+
@test dne + 1 == 1
6+
@test 1 + dne == 1
7+
@test dne * dne == dne
8+
@test dne * 1 == dne
9+
@test 1 * dne == dne
10+
11+
@test Zero() + dne == dne
12+
@test dne + Zero() == dne
13+
14+
@test Zero() * dne == Zero()
15+
@test dne * Zero() == Zero()
16+
17+
for x in dne
18+
@test x === dne
19+
end
20+
@test broadcastable(dne) isa Ref{DoesNotExist}
21+
@test conj(dne) == dne
22+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Base.Broadcast: broadcastable
1010
@testset "ChainRulesCore" begin
1111
@testset "differentials" begin
1212
include("differentials/zero.jl")
13+
include("differentials/does_not_exist.jl")
1314
include("differentials/one.jl")
1415
include("differentials/thunks.jl")
1516
include("differentials/composite.jl")

0 commit comments

Comments
 (0)