Skip to content

Commit 0fabda9

Browse files
committed
=WIP Derivative wrt function
=Make frule wrt self and rrule wrt self different [WIP
1 parent ad1a7a4 commit 0fabda9

File tree

4 files changed

+263
-7
lines changed

4 files changed

+263
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.2.1-DEV"
3+
version = "v0.3.0"
44

55
[compat]
66
julia = "^1.0"

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
44
export AbstractRule, Rule, frule, rrule
55
export @scalar_rule, @thunk
66
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
7+
export NO_FIELDS_RULE, ZERO_RULE
78

89
include("differentials.jl")
910
include("differential_arithmetic.jl")

src/rule_definition_tools.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@ methods for `frule` and `rrule`:
1414
function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...)
1515
Ω = f(x₁, x₂, ...)
1616
\$(statement₁, statement₂, ...)
17-
return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
17+
return Ω, (ZERO_RULE,
18+
Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
1819
Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
1920
...)
2021
end
2122
2223
function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
2324
Ω = f(x₁, x₂, ...)
2425
\$(statement₁, statement₂, ...)
25-
return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
26+
return Ω, (NO_FIELDS_RULE,
27+
Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
2628
Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
2729
...)
2830
end
@@ -34,11 +36,16 @@ Constraints may also be explicitly be provided to override the `Number` constrai
3436
e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to
3537
`Number`.
3638
37-
Note that the result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
39+
At present this does not support defining rules for closures/functors.
40+
Thus the first returned rule, representing the derivative with respect to the
41+
function itself, is always the `NO_FIELDS_RULE` (reverse-mode),
42+
or `ZERO_RULE` (forward-mode).
43+
44+
The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
3845
allows the primal result to be conveniently referenced (as `Ω`) within the
3946
derivative/setup expressions.
4047
41-
Note that the `@setup` argument can be elided if no setup code is need. In other
48+
The `@setup` argument can be elided if no setup code is needed. In other
4249
words:
4350
4451
@scalar_rule(f(x₁, x₂, ...),
@@ -92,9 +99,18 @@ macro scalar_rule(call, maybe_setup, partials...)
9299
forward_rules = Any[rule_from_partials(inputs[1], partial) for partial in partials]
93100
reverse_rules = Any[rule_from_partials(inputs[1], partials...)]
94101
end
95-
forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...)
96-
reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...)
102+
103+
# First pseudo-partial is derivative WRT function itself. Since this macro does not
104+
# support closures, it is `ZERO_RULE` (forward) or `NO_FIELDS_RULE`, i.e. the empty NamedTuple (reverse)
105+
forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...)
106+
reverse_rules = Expr(:tuple, NO_FIELDS_RULE, reverse_rules...)
97107
return quote
108+
if fieldcount(typeof($f)) > 0
109+
throw(ArgumentError(
110+
"@scalar_rule cannot be used on closures/functors (such as $f)"
111+
))
112+
end
113+
98114
function ChainRulesCore.frule(::typeof($f), $(inputs...))
99115
$(esc()) = $call
100116
$(setup_stmts...)

src/rules.jl

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,242 @@
1+
"""
2+
Subtypes of `AbstractRule` are types which represent the primitive derivative
3+
propagation "rules" that can be composed to implement forward- and reverse-mode
4+
automatic differentiation.
5+
6+
More specifically, a `rule::AbstractRule` is a callable Julia object generally
7+
obtained via calling [`frule`](@ref) or [`rrule`](@ref). Such rules accept
8+
differential values as input, evaluate the chain rule using internally stored/
9+
computed partial derivatives to produce a single differential value, then
10+
return that calculated differential value.
11+
12+
For example:
13+
14+
```jldoctest
15+
julia> using ChainRulesCore: frule, rrule, AbstractRule
16+
17+
julia> x, y = rand(2);
18+
19+
julia> h, dh = frule(hypot, x, y);
20+
21+
julia> h == hypot(x, y)
22+
true
23+
24+
julia> isa(dh, AbstractRule)
25+
true
26+
27+
julia> Δx, Δy = rand(2);
28+
29+
julia> dh(Δx, Δy) == ((x / h) * Δx + (y / h) * Δy)
30+
true
31+
32+
julia> h, (dx, dy) = rrule(hypot, x, y);
33+
34+
julia> h == hypot(x, y)
35+
true
36+
37+
julia> isa(dx, AbstractRule) && isa(dy, AbstractRule)
38+
true
39+
40+
julia> Δh = rand();
41+
42+
julia> dx(Δh) == (x / h) * Δh
43+
true
44+
45+
julia> dy(Δh) == (y / h) * Δh
46+
true
47+
```
48+
49+
See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [`WirtingerRule`](@ref)
50+
"""
51+
abstract type AbstractRule end
52+
53+
# this ensures that consumers don't have to special-case rule destructuring
54+
Base.iterate(rule::AbstractRule) = (rule, nothing)
55+
Base.iterate(::AbstractRule, ::Any) = nothing
56+
57+
# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple
58+
# in order to get the `i`th rule (assuming it's 1)
59+
# A lone rule behaves like a 1-element collection: index 1 returns the rule itself; any other
# index throws a `BoundsError` carrying the rule and the offending index for a useful message.
Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError(rule, i))
60+
61+
"""
62+
accumulate(Δ, rule::AbstractRule, args...)
63+
64+
Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore's
65+
various `AbstractDifferential` types.
66+
67+
This method is intended to be customizable for specific rules/input types. For
68+
example, here is pseudocode to overload `accumulate` w.r.t. a specific forward
69+
differentiation rule for a given function `f`:
70+
71+
```
72+
df(x) = # forward differentiation primitive implementation
73+
74+
frule(::typeof(f), x) = (f(x), Rule(df))
75+
76+
accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation
77+
```
78+
79+
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
80+
"""
81+
accumulate(Δ, rule::AbstractRule, args...) = add(Δ, rule(args...))
82+
83+
"""
84+
accumulate!(Δ, rule::AbstractRule, args...)
85+
86+
Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place,
87+
storing the result in `Δ`.
88+
89+
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
90+
91+
See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
92+
"""
93+
function accumulate!(Δ, rule::AbstractRule, args...)
94+
return materialize!(Δ, broadcastable(add(cast(Δ), rule(args...))))
95+
end
96+
97+
# `Number`s cannot be updated in-place, so fall back to the out-of-place `accumulate`.
accumulate!(Δ::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...)
98+
99+
"""
100+
store!(Δ, rule::AbstractRule, args...)
101+
102+
Compute `rule(args...)` and store the result in `Δ`, potentially avoiding
103+
intermediate temporary allocations that might be necessary for alternative
104+
approaches (e.g. `copyto!(Δ, extern(rule(args...)))`)
105+
106+
Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`.
107+
108+
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
109+
to be customizable for specific rules/input types.
110+
111+
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
112+
"""
113+
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))
114+
115+
#####
116+
##### `Rule`
117+
#####
118+
119+
Cassette.@context RuleContext
120+
121+
const RULE_CONTEXT = Cassette.disablehooks(RuleContext())
122+
123+
Cassette.overdub(::RuleContext, ::typeof(+), a, b) = add(a, b)
124+
Cassette.overdub(::RuleContext, ::typeof(*), a, b) = mul(a, b)
125+
126+
Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b)
127+
Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b)
128+
129+
"""
130+
Rule(propagation_function[, updating_function])
131+
132+
Return a `Rule` that wraps the given `propagation_function`. It is assumed that
133+
`propagation_function` is a callable object whose arguments are differential
134+
values, and whose output is a single differential value calculated by applying
135+
internally stored/computed partial derivatives to the input differential
136+
values.
137+
138+
If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)`
139+
and to store the result of the propagation function applied to the arguments `xs` into
140+
`Δ` in-place, returning `Δ`.
141+
142+
For example:
143+
144+
```
145+
frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)
146+
147+
rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))
148+
```
149+
150+
See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref)
151+
"""
152+
struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule
153+
f::F
154+
u::U
155+
end
156+
157+
# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some
158+
# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}`
159+
Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)
160+
161+
(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)
162+
163+
# Specialized accumulation
164+
# TODO: Does this need to be overdubbed in the rule context?
165+
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)
166+
167+
168+
"""
169+
NO_FIELDS_RULE
170+
171+
Constant for the rule for the derivative with respect to a structure that has no fields.
172+
The most notable use for this is for the reverse-mode derivative with respect to the
173+
function itself, when that function is not a closure.
174+
The rule returns an empty `NamedTuple` for all inputs.
175+
"""
176+
const NO_FIELDS_RULE = Rule((args...)->NamedTuple())
177+
178+
"""
179+
ZERO_RULE
180+
181+
This is a rule that returns `Zero()` regardless of input.
182+
The most notable use for this is for the forward-mode derivative with respect to the
183+
function itself, when that function is not a closure.
184+
"""
185+
const ZERO_RULE = Rule((args...)->Zero())
186+
187+
188+
189+
#####
190+
##### `DNERule`
191+
#####
192+
193+
"""
194+
DNERule(args...)
195+
196+
Construct a `DNERule` object, which is an `AbstractRule` that signifies that the
197+
current function is not differentiable with respect to a particular parameter.
198+
**DNE** is an abbreviation for Does Not Exist.
199+
"""
200+
struct DNERule <: AbstractRule end
201+
202+
DNERule(args...) = DNE()
203+
204+
#####
205+
##### `WirtingerRule`
206+
#####
207+
208+
"""
209+
WirtingerRule(primal::AbstractRule, conjugate::AbstractRule)
210+
211+
Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of
212+
an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate
213+
derivative ``∂/∂x̅``. If the domain `𝒟` of the function might be real, consider
214+
calling `AbstractRule(𝒟, primal, conjugate)` instead, to make use of a more
215+
efficient representation wherever possible.
216+
"""
217+
struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule
218+
primal::P
219+
conjugate::C
220+
end
221+
222+
function (rule::WirtingerRule)(args...)
223+
return Wirtinger(rule.primal(args...), rule.conjugate(args...))
224+
end
225+
226+
"""
227+
AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
228+
229+
Return a `Rule` evaluating to `primal(Δ) + conjugate(Δ)` if `𝒟 <: Real`,
230+
otherwise return `WirtingerRule(primal, conjugate)`.
231+
"""
232+
function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule)
233+
if 𝒟 <: Real || eltype(𝒟) <: Real
234+
return Rule((args...) -> add(primal(args...), conjugate(args...)))
235+
else
236+
return WirtingerRule(primal, conjugate)
237+
end
238+
end
239+
1240
#####
2241
##### `frule`/`rrule`
3242
#####

0 commit comments

Comments
 (0)