Skip to content

Commit 52fb1e0

Browse files
Merge pull request #32 from JuliaDiff/ox/moveruletypes
Move rule types out the their own file
2 parents f34bf9c + 87a5b44 commit 52fb1e0

File tree

6 files changed

+292
-289
lines changed

6 files changed

+292
-289
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export @scalar_rule, @thunk
77
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
88

99
include("differentials.jl")
10+
include("rule_types.jl")
1011
include("rules.jl")
1112
include("rule_definition_tools.jl")
1213
end # module

src/rule_types.jl

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""
2+
Subtypes of `AbstractRule` are types which represent the primitive derivative
3+
propagation "rules" that can be composed to implement forward- and reverse-mode
4+
automatic differentiation.
5+
6+
More specifically, a `rule::AbstractRule` is a callable Julia object generally
7+
obtained via calling [`frule`](@ref) or [`rrule`](@ref). Such rules accept
8+
differential values as input, evaluate the chain rule using internally stored/
9+
computed partial derivatives to produce a single differential value, then
10+
return that calculated differential value.
11+
12+
For example:
13+
14+
```jldoctest
15+
julia> using ChainRulesCore: frule, rrule, AbstractRule
16+
17+
julia> x, y = rand(2);
18+
19+
julia> h, dh = frule(hypot, x, y);
20+
21+
julia> h == hypot(x, y)
22+
true
23+
24+
julia> isa(dh, AbstractRule)
25+
true
26+
27+
julia> Δx, Δy = rand(2);
28+
29+
julia> dh(Δx, Δy) == ((x / h) * Δx + (y / h) * Δy)
30+
true
31+
32+
julia> h, (dx, dy) = rrule(hypot, x, y);
33+
34+
julia> h == hypot(x, y)
35+
true
36+
37+
julia> isa(dx, AbstractRule) && isa(dy, AbstractRule)
38+
true
39+
40+
julia> Δh = rand();
41+
42+
julia> dx(Δh) == (x / h) * Δh
43+
true
44+
45+
julia> dy(Δh) == (y / h) * Δh
46+
true
47+
```
48+
49+
See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [`WirtingerRule`](@ref)
50+
"""
51+
abstract type AbstractRule end
52+
53+
# this ensures that consumers don't have to special-case rule destructuring
54+
Base.iterate(rule::AbstractRule) = (rule, nothing)
55+
Base.iterate(::AbstractRule, ::Any) = nothing
56+
57+
# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple
58+
# in order to get the `i`th rule (assuming it's 1)
59+
Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError())
60+
61+
"""
62+
accumulate(Δ, rule::AbstractRule, args...)
63+
64+
Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore'
65+
various `AbstractDifferential` types.
66+
67+
This method intended to be customizable for specific rules/input types. For
68+
example, here is pseudocode to overload `accumulate` w.r.t. a specific forward
69+
differentiation rule for a given function `f`:
70+
71+
```
72+
df(x) = # forward differentiation primitive implementation
73+
74+
frule(::typeof(f), x) = (f(x), Rule(df))
75+
76+
accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation
77+
```
78+
79+
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
80+
"""
81+
accumulate(Δ, rule::AbstractRule, args...) = add(Δ, rule(args...))
82+
83+
"""
84+
accumulate!(Δ, rule::AbstractRule, args...)
85+
86+
Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place,
87+
storing the result in `Δ`.
88+
89+
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
90+
91+
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
92+
"""
93+
function accumulate!(Δ, rule::AbstractRule, args...)
94+
return materialize!(Δ, broadcastable(add(cast(Δ), rule(args...))))
95+
end
96+
97+
accumulate!::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...)
98+
99+
"""
100+
store!(Δ, rule::AbstractRule, args...)
101+
102+
Compute `rule(args...)` and store the result in `Δ`, potentially avoiding
103+
intermediate temporary allocations that might be necessary for alternative
104+
approaches (e.g. `copyto!(Δ, extern(rule(args...)))`)
105+
106+
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
107+
108+
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
109+
to be customizable for specific rules/input types.
110+
111+
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
112+
"""
113+
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))
114+
115+
#####
116+
##### `Rule`
117+
#####
118+
119+
Cassette.@context RuleContext
120+
121+
const RULE_CONTEXT = Cassette.disablehooks(RuleContext())
122+
123+
Cassette.overdub(::RuleContext, ::typeof(+), a, b) = add(a, b)
124+
Cassette.overdub(::RuleContext, ::typeof(*), a, b) = mul(a, b)
125+
126+
Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b)
127+
Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b)
128+
129+
"""
130+
Rule(propation_function[, updating_function])
131+
132+
Return a `Rule` that wraps the given `propation_function`. It is assumed that
133+
`propation_function` is a callable object whose arguments are differential
134+
values, and whose output is a single differential value calculated by applying
135+
internally stored/computed partial derivatives to the input differential
136+
values.
137+
138+
If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)`
139+
and to store the result of the propagation function applied to the arguments `xs` into
140+
`Δ` in-place, returning `Δ`.
141+
142+
For example:
143+
144+
```
145+
frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)
146+
147+
rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))
148+
```
149+
150+
See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref)
151+
"""
152+
struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule
153+
f::F
154+
u::U
155+
end
156+
157+
# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some
158+
# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}`
159+
Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
160+
161+
(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)
162+
163+
# Specialized accumulation
164+
# TODO: Does this need to be overdubbed in the rule context?
165+
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
166+
167+
#####
168+
##### `DNERule`
169+
#####
170+
171+
"""
172+
DNERule(args...)
173+
174+
Construct a `DNERule` object, which is an `AbstractRule` that signifies that the
175+
current function is not differentiable with respect to a particular parameter.
176+
**DNE** is an abbreviation for Does Not Exist.
177+
"""
178+
struct DNERule <: AbstractRule end
179+
180+
DNERule(args...) = DNE()
181+
182+
#####
183+
##### `WirtingerRule`
184+
#####
185+
186+
"""
187+
WirtingerRule(primal::AbstractRule, conjugate::AbstractRule)
188+
189+
Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of
190+
an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate
191+
derivative ``∂/∂x̅``. If the domain `𝒟` of the function might be real, consider
192+
calling `AbstractRule(𝒟, primal, conjugate)` instead, to make use of a more
193+
efficient representation wherever possible.
194+
"""
195+
struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule
196+
primal::P
197+
conjugate::C
198+
end
199+
200+
function (rule::WirtingerRule)(args...)
201+
return Wirtinger(rule.primal(args...), rule.conjugate(args...))
202+
end
203+
204+
"""
205+
AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
206+
207+
Return a `Rule` evaluating to `primal(Δ) + conjugate(Δ)` if `𝒟 <: Real`,
208+
otherwise return `WirtingerRule(P, C)`.
209+
"""
210+
function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
211+
if 𝒟 <: Real || eltype(𝒟) <: Real
212+
return Rule((args...) -> add(primal(args...), conjugate(args...)))
213+
else
214+
return WirtingerRule(primal, conjugate)
215+
end
216+
end
217+

0 commit comments

Comments
 (0)