Skip to content

Commit ef748c9

Browse files
committed
Get rid of all Rule types, add InplacableThunk
spelling is hard
1 parent f14e045 commit ef748c9

File tree

7 files changed

+104
-299
lines changed

7 files changed

+104
-299
lines changed

src/ChainRulesCore.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
module ChainRulesCore
22
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
33

4-
export AbstractRule, Rule, frule, rrule
4+
export frule, rrule
55
export wirtinger_conjugate, wirtinger_primal, differential
66
export @scalar_rule, @thunk
77
export extern, cast, store!
8-
export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
8+
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
99
export NO_FIELDS
1010

1111
include("differentials.jl")
1212
include("differential_arithmetic.jl")
13-
include("rule_types.jl")
13+
include("operations.jl")
1414
include("rules.jl")
1515
include("rule_definition_tools.jl")
1616
end # module

src/differential_arithmetic.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
77
Thus we can avoid any ambiguities.
88
99
Notice:
10-
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :Any)
10+
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any)
1111
Thus each of the @eval loops creating definitions of + and *
1212
defines the combination this type with all types of lower precidence.
1313
This means each eval loops is 1 item smaller than the previous.
@@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
3636
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
3737
end
3838

39-
for T in (:Casted, :Zero, :DNE, :One, :Thunk, :Any)
39+
for T in (:Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any)
4040
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
4141
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
4242

@@ -47,7 +47,7 @@ end
4747

4848
Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
4949
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
50-
for T in (:Zero, :DNE, :One, :Thunk, :Any)
50+
for T in (:Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any)
5151
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
5252
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))
5353

@@ -58,7 +58,7 @@ end
5858

5959
Base.:+(::Zero, b::Zero) = Zero()
6060
Base.:*(::Zero, ::Zero) = Zero()
61-
for T in (:DNE, :One, :Thunk, :Any)
61+
for T in (:DNE, :One, :Thunk, :InplaceableThunk, :Any)
6262
@eval Base.:+(::Zero, b::$T) = b
6363
@eval Base.:+(a::$T, ::Zero) = a
6464

@@ -69,7 +69,7 @@ end
6969

7070
Base.:+(::DNE, ::DNE) = DNE()
7171
Base.:*(::DNE, ::DNE) = DNE()
72-
for T in (:One, :Thunk, :Any)
72+
for T in (:One, :Thunk, :InplaceableThunk, :Any)
7373
@eval Base.:+(::DNE, b::$T) = b
7474
@eval Base.:+(a::$T, ::DNE) = a
7575

@@ -80,7 +80,7 @@ end
8080

8181
Base.:+(a::One, b::One) = extern(a) + extern(b)
8282
Base.:*(::One, ::One) = One()
83-
for T in (:Thunk, :Any)
83+
for T in (:Thunk, :InplaceableThunk, :Any)
8484
@eval Base.:+(a::One, b::$T) = extern(a) + b
8585
@eval Base.:+(a::$T, b::One) = a + extern(b)
8686

@@ -91,10 +91,21 @@ end
9191

9292
Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b)
9393
Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b)
94-
for T in (:Any,) #This loop is redundant but for consistency...
94+
for T in (:InplaceableThunk, :Any)
9595
@eval Base.:+(a::Thunk, b::$T) = extern(a) + b
9696
@eval Base.:+(a::$T, b::Thunk) = a + extern(b)
9797

9898
@eval Base.:*(a::Thunk, b::$T) = extern(a) * b
9999
@eval Base.:*(a::$T, b::Thunk) = a * extern(b)
100100
end
101+
102+
# InplaceableThunk acts just like Thunk
103+
Base.:+(a::InplaceableThunk, b::InplaceableThunk) = extern(a) + extern(b)
104+
Base.:*(a::InplaceableThunk, b::InplaceableThunk) = extern(a) * extern(b)
105+
for T in (:Any, )
106+
@eval Base.:+(a::InplaceableThunk, b::$T) = extern(a) + b
107+
@eval Base.:+(a::$T, b::InplaceableThunk) = a + extern(b)
108+
109+
@eval Base.:*(a::InplaceableThunk, b::$T) = extern(a) * b
110+
@eval Base.:*(a::$T, b::InplaceableThunk) = a * extern(b)
111+
end

src/differentials.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,43 @@ Base.conj(x::Thunk) = @thunk(conj(extern(x)))
246246

247247
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
248248

249+
"""
250+
InplaceableThunk(val::Thunk, add!::Function)
251+
252+
A wrapper for a `Thunk`, that allows it to define an inplace `add!` function,
253+
which is used internally in `accumulate!(Δ, ::InplaceableThunk)`.
254+
255+
`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val`
256+
but it should do this more efficently than simply doing this directly.
257+
(Otherwise one can just use a normal `Thunk`).
258+
259+
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
260+
and destroy its inplacability.
261+
"""
262+
struct InplaceableThunk{T<:Thunk, F} <: AbstractDifferential
263+
val::T
264+
add!::F
265+
end
266+
267+
(x::InplaceableThunk)() = x.val()
268+
@inline extern(x::InplaceableThunk) = extern(x.val)
269+
270+
Base.Broadcast.broadcastable(x::InplaceableThunk) = broadcastable(x.val)
271+
272+
@inline function Base.iterate(x::InplaceableThunk, args...)
273+
return iterate(x.val, args...)
274+
end
275+
276+
Base.conj(x::InplaceableThunk) = conj(x.val)
277+
278+
function Base.show(io::IO, x::InplaceableThunk)
279+
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
280+
end
281+
282+
# The real reason we have this:
283+
accumulate!(Δ, ∂::InplaceableThunk) =.add!(Δ)
284+
285+
249286
"""
250287
NO_FIELDS
251288
@@ -255,7 +292,6 @@ function itself, when that function is not a closure.
255292
"""
256293
const NO_FIELDS = DNE()
257294

258-
####
259295
"""
260296
differential(𝒟::Type, der)
261297

src/operations.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# TODO: This all needs a fair bit of rethinking
2+
3+
"""
4+
accumulate(Δ, ∂)
5+
6+
Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's
7+
various `AbstractDifferential` types.
8+
9+
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
10+
"""
11+
accumulate(Δ, ∂) = Δ .+
12+
13+
"""
14+
accumulate!(Δ, ∂)
15+
16+
Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place,
17+
storing the result in `Δ`.
18+
19+
Note: this function may not actually store the result in `Δ` if `Δ` is immutable,
20+
so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case.
21+
22+
This function is overloadable by [`InplaceableThunk`s](@ref).
23+
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
24+
"""
25+
function accumulate!(Δ, ∂)
26+
return materialize!(Δ, broadcastable(cast(Δ) + ∂))
27+
end
28+
29+
accumulate!::Number, ∂) = accumulate(Δ, ∂)
30+
31+
32+
33+
"""
34+
store!(Δ, ∂)
35+
36+
Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before.
37+
potentially avoiding intermediate temporary allocations that might be
38+
necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`)
39+
40+
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
41+
to be customizable for specific rules/input types.
42+
43+
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
44+
"""
45+
store!(Δ, ∂) = materialize!(Δ, broadcastable(∂))

0 commit comments

Comments
 (0)