Skip to content

Commit e74b156

Browse files
committed
Reorganize all the files
1 parent 18eea00 commit e74b156

File tree

15 files changed

+641
-22
lines changed

15 files changed

+641
-22
lines changed

src/ChainRulesCore.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@ export extern, store!, unthunk
88
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Wirtinger, Zero
99
export NO_FIELDS
1010

11-
include("differentials.jl")
12-
include("composite_core.jl")
11+
include("differentials/abstract_differential.jl")
12+
include("differentials/wirtinger.jl")
13+
include("differentials/zero.jl")
14+
include("differentials/does_not_exist.jl")
15+
include("differentials/one.jl")
16+
include("differentials/thunks.jl")
17+
include("differentials/composite.jl")
18+
1319
include("differential_arithmetic.jl")
1420

1521
include("operations.jl")
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#####
2+
##### `AbstractDifferential`
3+
#####
4+
5+
"""
6+
The subtypes of `AbstractDifferential` define a custom \"algebra\" for chain
7+
rule evaluation that attempts to factor various features like complex derivative
8+
support, broadcast fusion, zero-elision, etc. into nicely separated parts.
9+
10+
All subtypes of `AbstractDifferential` implement the following operations:
11+
12+
`+(a, b)`: linearly combine differential `a` and differential `b`
13+
14+
`*(a, b)`: multiply the differential `a` by the differential `b`
15+
16+
`Base.conj(x)`: complex conjugate of the differential `x`
17+
18+
`extern(x)`: convert `x` into an appropriate non-`AbstractDifferential` type for
19+
use outside of `ChainContext`.
20+
21+
Valid arguments to these operations are `T` where `T<:AbstractDifferential`, or
22+
where `T` has proper `+` and `*` implementations.
23+
24+
Additionally, all subtypes of `AbstractDifferential` support `Base.iterate` and
25+
`Base.Broadcast.broadcastable(x)`.
26+
"""
27+
abstract type AbstractDifferential end
28+
29+
Base.:+(x::AbstractDifferential) = x
30+
31+
"""
32+
extern(x)
33+
34+
Return `x` converted to an appropriate non-`AbstractDifferential` type, for use
35+
with external packages that might not handle `AbstractDifferential` types.
36+
37+
Note that this function may return an alias (not necessarily a copy) to data
38+
wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
39+
"""
40+
@inline extern(x) = x
41+
42+
@inline Base.conj(x::AbstractDifferential) = x
43+
44+
"""
45+
refine_differential(𝒟::Type, der)
46+
47+
Converts, if required, a differential object `der`
48+
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
49+
to another differential that is more suited for the domain given by the type 𝒟.
50+
Often this will behave as the identity function on `der`.
51+
"""
52+
refine_differential(::Any, der) = der # most of the time leave it alone.

src/composite_core.jl renamed to src/differentials/composite.jl

Lines changed: 94 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,68 @@
1-
struct PrimalAdditionFailedException{P} <: Exception
2-
primal::P
3-
differential::Composite{P}
4-
original::Exception
1+
"""
2+
Composite{P, T} <: AbstractDifferential
3+
4+
This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
5+
`P` is the corresponding primal type that this is a differential for.
6+
7+
`Composite{P}` should have fields (technically properties), that match to a subset of the
8+
fields of the primal type; and each should be a differential type matching to the primal
9+
type of that field.
10+
Fields of the P that are not present in the Composite are treated as `Zero`.
11+
12+
`T` is an implementation detail representing the backing data structure.
13+
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
14+
It should not be passed in by user.
15+
"""
16+
struct Composite{P, T} <: AbstractDifferential
17+
# Note: If T is a Tuple, then P is also a Tuple
18+
# (but potentially a different one, as it doesn't contain differentials)
19+
backing::T
520
end
621

7-
function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
8-
println(io, "Could not construct $P after addition.")
9-
println(io, "This probably means no default constructor is defined.")
10-
println(io, "Either define a default constructor")
11-
printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue)
12-
println(io, "\nor overload")
13-
printstyled(io,
14-
"ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))";
15-
color=:blue
16-
)
17-
println(io, "\nor overload")
18-
printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue)
19-
println(io, "\nOriginal Exception:")
20-
printstyled(io, err.original; color=:yellow)
21-
println(io)
22+
function Composite{P}(; kwargs...) where P
23+
backing = (; kwargs...) # construct as NamedTuple
24+
return Composite{P, typeof(backing)}(backing)
25+
end
26+
27+
function Composite{P}(args...) where P
28+
return Composite{P, typeof(args)}(args)
29+
end
30+
31+
function Base.show(io::IO, comp::Composite{P}) where P
32+
print(io, "Composite{")
33+
show(io, P)
34+
print(io, "}")
35+
# allow Tuple or NamedTuple `show` to do the rendering of brackets etc
36+
show(io, backing(comp))
2237
end
2338

39+
Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = backing(comp)
40+
Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = backing(comp)
41+
42+
Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx)
43+
Base.getproperty(comp::Composite, idx::Int) = getproperty(backing(comp), idx) # for Tuple
44+
Base.getproperty(comp::Composite, idx::Symbol) = getproperty(backing(comp), idx)
45+
Base.propertynames(comp::Composite) = propertynames(backing(comp))
46+
47+
Base.iterate(comp::Composite, args...) = iterate(backing(comp), args...)
48+
Base.length(comp::Composite) = length(backing(comp))
49+
Base.eltype(::Type{<:Composite{<:Any, T}}) where T = eltype(T)
50+
51+
function Base.map(f, comp::Composite{P, <:Tuple}) where P
52+
vals::Tuple = map(f, backing(comp))
53+
return Composite{P, typeof(vals)}(vals)
54+
end
55+
function Base.map(f, comp::Composite{P, <:NamedTuple{L}}) where{P, L}
56+
vals = map(f, Tuple(backing(comp)))
57+
named_vals = NamedTuple{L, typeof(vals)}(vals)
58+
return Composite{P, typeof(named_vals)}(named_vals)
59+
end
60+
61+
Base.conj(comp::Composite) = map(conj, comp)
62+
63+
extern(comp::Composite) = backing(map(extern, comp)) # gives a NamedTuple or Tuple
64+
65+
2466
"""
2567
backing(x)
2668
@@ -131,3 +173,36 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn}
131173
return NamedTuple{names,types}(vals)
132174
end
133175
end
176+
177+
178+
struct PrimalAdditionFailedException{P} <: Exception
179+
primal::P
180+
differential::Composite{P}
181+
original::Exception
182+
end
183+
184+
function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
185+
println(io, "Could not construct $P after addition.")
186+
println(io, "This probably means no default constructor is defined.")
187+
println(io, "Either define a default constructor")
188+
printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue)
189+
println(io, "\nor overload")
190+
printstyled(io,
191+
"ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))";
192+
color=:blue
193+
)
194+
println(io, "\nor overload")
195+
printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue)
196+
println(io, "\nOriginal Exception:")
197+
printstyled(io, err.original; color=:yellow)
198+
println(io)
199+
end
200+
201+
"""
202+
NO_FIELDS
203+
204+
Constant for the reverse-mode derivative with respect to a structure that has no fields.
205+
The most notable use for this is for the reverse-mode derivative with respect to the
206+
function itself, when that function is not a closure.
207+
"""
208+
const NO_FIELDS = DoesNotExist()

src/differentials/does_not_exist.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
DoesNotExist()
3+
4+
This differential indicates that the derivative Does Not Exist (D.N.E).
5+
This is not the case that it is not implemented, but rather that it mathematically
6+
is not defined.
7+
"""
8+
struct DoesNotExist <: AbstractDifferential end
9+
10+
function extern(x::DoesNotExist)
11+
throw(ArgumentError("Derivative does not exit. Cannot be converted to an external type."))
12+
end
13+
14+
Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist())
15+
16+
Base.iterate(x::DoesNotExist) = (x, nothing)
17+
Base.iterate(::DoesNotExist, ::Any) = nothing
18+

src/differentials/one.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
One()
3+
The Differential which is the multiplicative identity.
4+
Basically, this represents `1`.
5+
"""
6+
struct One <: AbstractDifferential end
7+
8+
extern(x::One) = true # true is a strong 1.
9+
10+
Base.Broadcast.broadcastable(::One) = Ref(One())
11+
12+
Base.iterate(x::One) = (x, nothing)
13+
Base.iterate(::One, ::Any) = nothing
14+

src/differentials/thunks.jl

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
2+
abstract type AbstractThunk <: AbstractDifferential end
3+
4+
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))
5+
6+
@inline function Base.iterate(x::AbstractThunk)
7+
externed = extern(x)
8+
element, state = iterate(externed)
9+
return element, (externed, state)
10+
end
11+
12+
@inline function Base.iterate(::AbstractThunk, (externed, state))
13+
element, new_state = iterate(externed, state)
14+
return element, (externed, new_state)
15+
end
16+
17+
#####
18+
##### `Thunk`
19+
#####
20+
21+
"""
22+
Thunk(()->v)
23+
A thunk is a deferred computation.
24+
It wraps a zero argument closure that when invoked returns a differential.
25+
`@thunk(v)` is a macro that expands into `Thunk(()->v)`.
26+
27+
Calling a thunk, calls the wrapped closure.
28+
`extern`ing thunks applies recursively; it also externs the differential that the closure returns.
29+
If you do not want that, then simply call the thunk
30+
31+
```
32+
julia> t = @thunk(@thunk(3))
33+
Thunk(var"##7#9"())
34+
35+
julia> extern(t)
36+
3
37+
38+
julia> t()
39+
Thunk(var"##8#10"())
40+
41+
julia> t()()
42+
3
43+
```
44+
45+
### When to `@thunk`?
46+
When writing `rrule`s (and to a lesser extent `frule`s), it is important to `@thunk`
47+
appropriately.
48+
Propagation rules that return multiple derivatives are not able to do all the computing themselves.
49+
By `@thunk`ing the work required for each, they then compute only what is needed.
50+
51+
#### So why not thunk everything?
52+
`@thunk` creates a closure over the expression, which (effectively) creates a `struct`
53+
with a field for each variable used in the expression, and call overloaded.
54+
55+
Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being:
56+
- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
57+
- The expression being a constant
58+
- The expression being itself a `thunk`
59+
- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already)
60+
"""
61+
struct Thunk{F} <: AbstractThunk
62+
f::F
63+
end
64+
65+
66+
"""
67+
@thunk expr
68+
69+
Define a [`Thunk`](@ref) wrapping the `expr`, to lazily defer its evaluation.
70+
"""
71+
macro thunk(body)
72+
# Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined.
73+
# so we get useful stack traces if it errors.
74+
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
75+
return :(Thunk($(esc(func))))
76+
end
77+
78+
"""
79+
unthunk(x)
80+
81+
On `AbstractThunk`s this removes 1 layer of thunking.
82+
On any other type, it is the identity operation.
83+
84+
In contrast to `extern` this is nonrecursive.
85+
"""
86+
@inline unthunk(x) = x
87+
88+
@inline extern(x::AbstractThunk) = extern(unthunk(x))
89+
90+
# have to define this here after `@thunk` and `Thunk` is defined
91+
Base.conj(x::AbstractThunk) = @thunk(conj(unthunk(x)))
92+
93+
94+
(x::Thunk)() = x.f()
95+
@inline unthunk(x::Thunk) = x()
96+
97+
# Render a `Thunk` as `Thunk(<repr of wrapped closure>)`.
# `show` must not emit a trailing newline (so `repr` and nested displays stay on one line).
Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))")
98+
99+
"""
100+
InplaceableThunk(val::Thunk, add!::Function)
101+
102+
A wrapper for a `Thunk`, that allows it to define an inplace `add!` function,
103+
which is used internally in `accumulate!(Δ, ::InplaceableThunk)`.
104+
105+
`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val`
106+
but it should do this more efficiently than simply doing this directly.
107+
(Otherwise one can just use a normal `Thunk`).
108+
109+
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
110+
and destroy its inplaceability.
111+
"""
112+
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
113+
val::T
114+
add!::F
115+
end
116+
117+
unthunk(x::InplaceableThunk) = unthunk(x.val)
118+
(x::InplaceableThunk)() = unthunk(x)
119+
120+
function Base.show(io::IO, x::InplaceableThunk)
121+
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
122+
end
123+
124+
# The real reason we have this:
125+
# Accumulate into `Δ` by delegating to the in-place `add!` function stored on the thunk `∂`.
accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ)
126+
# Overwrite `Δ` with the thunk's value: `Δ .*= false` zeroes `Δ` in place, and the
# zeroed `Δ` is then handed to the `add!` function stored on the thunk `∂`.
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it.

0 commit comments

Comments
 (0)