Skip to content

Commit a133468

Browse files
authored
Merge pull request #30 from JuliaDiff/ox/wrtfunction
Overhaul Rules
2 parents ad1a7a4 + e51ff80 commit a133468

12 files changed

+509
-396
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.2.1-DEV"
3+
version = "0.3.0"
44

55
[compat]
66
julia = "^1.0"

src/ChainRulesCore.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
module ChainRulesCore
22
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
33

4-
export AbstractRule, Rule, frule, rrule
4+
export frule, rrule
5+
export wirtinger_conjugate, wirtinger_primal, refine_differential
56
export @scalar_rule, @thunk
6-
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
7+
export extern, cast, store!
8+
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
9+
export NO_FIELDS
710

811
include("differentials.jl")
912
include("differential_arithmetic.jl")
10-
include("rule_types.jl")
13+
include("operations.jl")
1114
include("rules.jl")
1215
include("rule_definition_tools.jl")
1316
end # module

src/differential_arithmetic.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
77
Thus we can avoid any ambiguities.
88
99
Notice:
10-
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :Any)
10+
The precedence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
1111
Thus each of the @eval loops creating definitions of + and *
1212
defines the combination of this type with all types of lower precedence.
1313
This means each eval loop is 1 item smaller than the previous.
@@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
3636
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
3737
end
3838

39-
for T in (:Casted, :Zero, :DNE, :One, :Thunk, :Any)
39+
for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
4040
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
4141
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
4242

@@ -47,7 +47,7 @@ end
4747

4848
Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
4949
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
50-
for T in (:Zero, :DNE, :One, :Thunk, :Any)
50+
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
5151
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
5252
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))
5353

@@ -58,7 +58,7 @@ end
5858

5959
Base.:+(::Zero, b::Zero) = Zero()
6060
Base.:*(::Zero, ::Zero) = Zero()
61-
for T in (:DNE, :One, :Thunk, :Any)
61+
for T in (:DNE, :One, :AbstractThunk, :Any)
6262
@eval Base.:+(::Zero, b::$T) = b
6363
@eval Base.:+(a::$T, ::Zero) = a
6464

@@ -69,7 +69,7 @@ end
6969

7070
Base.:+(::DNE, ::DNE) = DNE()
7171
Base.:*(::DNE, ::DNE) = DNE()
72-
for T in (:One, :Thunk, :Any)
72+
for T in (:One, :AbstractThunk, :Any)
7373
@eval Base.:+(::DNE, b::$T) = b
7474
@eval Base.:+(a::$T, ::DNE) = a
7575

@@ -80,7 +80,7 @@ end
8080

8181
Base.:+(a::One, b::One) = extern(a) + extern(b)
8282
Base.:*(::One, ::One) = One()
83-
for T in (:Thunk, :Any)
83+
for T in (:AbstractThunk, :Any)
8484
@eval Base.:+(a::One, b::$T) = extern(a) + b
8585
@eval Base.:+(a::$T, b::One) = a + extern(b)
8686

@@ -89,12 +89,12 @@ for T in (:Thunk, :Any)
8989
end
9090

9191

92-
Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b)
93-
Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b)
94-
for T in (:Any,) #This loop is redundant but for consistency...
95-
@eval Base.:+(a::Thunk, b::$T) = extern(a) + b
96-
@eval Base.:+(a::$T, b::Thunk) = a + extern(b)
92+
Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b)
93+
Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b)
94+
for T in (:Any,)
95+
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
96+
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)
9797

98-
@eval Base.:*(a::Thunk, b::$T) = extern(a) * b
99-
@eval Base.:*(a::$T, b::Thunk) = a * extern(b)
98+
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
99+
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
100100
end

src/differentials.jl

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing)
173173
Base.iterate(::One, ::Any) = nothing
174174

175175

176+
#####
177+
##### `AbstractThunk`
178+
#####
179+
abstract type AbstractThunk <: AbstractDifferential end
180+
181+
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))
182+
183+
@inline function Base.iterate(x::AbstractThunk)
184+
externed = extern(x)
185+
element, state = iterate(externed)
186+
return element, (externed, state)
187+
end
188+
189+
@inline function Base.iterate(::AbstractThunk, (externed, state))
190+
element, new_state = iterate(externed, state)
191+
return element, (externed, new_state)
192+
end
193+
176194
#####
177195
##### `Thunk`
178196
#####
@@ -181,8 +199,9 @@ Base.iterate(::One, ::Any) = nothing
181199
Thunk(()->v)
182200
A thunk is a deferred computation.
183201
It wraps a zero argument closure that when invoked returns a differential.
202+
`@thunk(v)` is a macro that expands into `Thunk(()->v)`.
184203
185-
Calling that thunk, calls the wrapped closure.
204+
Calling a thunk, calls the wrapped closure.
186205
`extern`ing thunks applies recursively: it also externs the differential that the closure returns.
187206
If you do not want that, then simply call the thunk
188207
@@ -199,31 +218,87 @@ Thunk(var"##8#10"())
199218
julia> t()()
200219
3
201220
```
221+
222+
### When to `@thunk`?
223+
When writing `rrule`s (and to a lesser extent `frule`s), it is important to `@thunk`
224+
appropriately.
225+
Propagation rules that return multiple derivatives are not able to do all the computing themselves.
226+
By `@thunk`ing the work required for each, they then compute only what is needed.
227+
228+
#### So why not thunk everything?
229+
`@thunk` creates a closure over the expression, which (effectively) creates a `struct`
230+
with a field for each variable used in the expression, and with the call operator overloaded.
231+
232+
Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being:
233+
- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
234+
- The expression being a constant
235+
- The expression being itself a `thunk`
236+
- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already)
202237
"""
203-
struct Thunk{F} <: AbstractDifferential
238+
struct Thunk{F} <: AbstractThunk
204239
f::F
205240
end
206241

207242
macro thunk(body)
208243
return :(Thunk(() -> $(esc(body))))
209244
end
210245

246+
# have to define this here, after `@thunk` and `Thunk` are defined
247+
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))
248+
249+
211250
(x::Thunk)() = x.f()
212251
@inline extern(x::Thunk) = extern(x())
213252

214-
Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x))
253+
Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))")  # `show` must not emit a trailing newline
215254

216-
@inline function Base.iterate(x::Thunk)
217-
externed = extern(x)
218-
element, state = iterate(externed)
219-
return element, (externed, state)
255+
"""
256+
InplaceableThunk(val::Thunk, add!::Function)
257+
258+
A wrapper for a `Thunk`, that allows it to define an inplace `add!` function,
259+
which is used internally in `accumulate!(Δ, ::InplaceableThunk)`.
260+
261+
`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val`
262+
but it should do this more efficiently than simply doing this directly.
263+
(Otherwise one can just use a normal `Thunk`).
264+
265+
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
266+
and destroy its inplaceability.
267+
"""
268+
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
269+
val::T
270+
add!::F
220271
end
221272

222-
@inline function Base.iterate(::Thunk, (externed, state))
223-
element, new_state = iterate(externed, state)
224-
return element, (externed, new_state)
273+
(x::InplaceableThunk)() = x.val()
274+
@inline extern(x::InplaceableThunk) = extern(x.val)
275+
276+
function Base.show(io::IO, x::InplaceableThunk)
277+
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
225278
end
226279

227-
Base.conj(x::Thunk) = @thunk(conj(extern(x)))
280+
# The real reason we have this:
281+
accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ)  # use the thunk's own in-place adder
282+
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it.
228283

229-
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
284+
"""
285+
NO_FIELDS
286+
287+
Constant for the reverse-mode derivative with respect to a structure that has no fields.
288+
The most notable use for this is for the reverse-mode derivative with respect to the
289+
function itself, when that function is not a closure.
290+
"""
291+
const NO_FIELDS = DNE()
292+
293+
"""
294+
refine_differential(𝒟::Type, der)
295+
296+
Converts, if required, a differential object `der`
297+
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
298+
to another differential that is more suited for the domain given by the type 𝒟.
299+
Often this will behave as the identity function on `der`.
300+
"""
301+
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
302+
return wirtinger_primal(w) + wirtinger_conjugate(w)
303+
end
304+
refine_differential(::Any, der) = der # most of the time leave it alone.

src/operations.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# TODO: This all needs a fair bit of rethinking
2+
3+
"""
4+
accumulate(Δ, ∂)
5+
6+
Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's
7+
various `AbstractDifferential` types.
8+
9+
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
10+
"""
11+
accumulate(Δ, ∂) = Δ .+ ∂  # broadcast `+` so scalar and array operands are both accepted
12+
13+
"""
14+
accumulate!(Δ, ∂)
15+
16+
Similar to [`accumulate`](@ref), but attempts to compute `Δ + ∂` in-place,
17+
storing the result in `Δ`.
18+
19+
Note: this function may not actually store the result in `Δ` if `Δ` is immutable,
20+
so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case.
21+
22+
This function is overloadable by using an [`InplaceableThunk`](@ref).
23+
See also: [`accumulate`](@ref), [`store!`](@ref).
24+
"""
25+
function accumulate!(Δ, ∂)
26+
return materialize!(Δ, broadcastable(cast(Δ) + ∂))
27+
end
28+
29+
accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂)  # numbers are immutable, so fall back to out-of-place
30+
31+
32+
33+
"""
34+
store!(Δ, ∂)
35+
36+
Stores `∂` in `Δ`, overwriting whatever was in `Δ` before,
37+
potentially avoiding intermediate temporary allocations that might be
38+
necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`)
39+
40+
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
41+
to be customizable for specific rules/input types.
42+
43+
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
44+
"""
45+
store!(Δ, ∂) = materialize!(Δ, broadcastable(∂))

0 commit comments

Comments
 (0)