Skip to content

Commit 2ef3f20

Browse files
Sort out frule API (#129)
* Change frule implementation * Add API regression tests * Core._apply * Remove unnecessary tail * Fix inference issue * Some tweaks * Update test/rules.jl Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au> * tweak style * Tweak docs further * Update Project.toml * Require tests pass on 1.3 * Update .travis.yml Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au> * Style Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 30f9852 commit 2ef3f20

File tree

6 files changed

+100
-39
lines changed

6 files changed

+100
-39
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ julia:
77
- 1.1
88
- 1.2
99
- 1.3
10+
- 1.4
1011
- nightly
1112
jobs:
1213
allow_failures:
13-
- julia: 1.3
1414
- julia: nightly
1515
include:
1616
- stage: "Documentation"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.6.1"
3+
version = "0.7.0"
44

55
[deps]
66
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

docs/src/index.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Almost always the _pullback_ will be declared locally within the `rrule`, and wi
5959

6060
The `frule` is written:
6161
```julia
62-
function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...)
62+
function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...)
6363
...
6464
return y, ∂Y
6565
end
@@ -175,15 +175,15 @@ end
175175
```
176176
But because it is fused into frule we see it as part of:
177177
```julia
178-
function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...)
178+
function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...)
179179
...
180180
return y, ∂y
181181
end
182182
```
183183

184184

185185
The input to the pushforward is often called the _perturbation_.
186-
If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule(f, x, ṡelf, ẋ))`.
186+
If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule((ṡelf, ẋ), f, x))`.
187187
`ẏ` is commonly used to represent the perturbation for `y`.
188188

189189
!!! note
@@ -238,14 +238,14 @@ If we would like to know the directional derivative of `f` for an input chan
238238

239239
```julia
240240
direction = (1.5, 0.4, -1) # (ȧ, ḃ, ċ)
241-
y, ẏ = frule(f, a, b, c, Zero(), direction)
241+
y, ẏ = frule((Zero(), direction...), f, a, b, c)
242242
```
243243

244244
On the basis directions one gets the partial derivatives of `y`:
245245
```julia
246-
y, ∂y_∂a = frule(f, a, b, c, Zero(), 1, 0, 0)
247-
y, ∂y_∂b = frule(f, a, b, c, Zero(), 0, 1, 0)
248-
y, ∂y_∂c = frule(f, a, b, c, Zero(), 0, 0, 1)
246+
y, ∂y_∂a = frule((Zero(), 1, 0, 0), f, a, b, c)
247+
y, ∂y_∂b = frule((Zero(), 0, 1, 0), f, a, b, c)
248+
y, ∂y_∂c = frule((Zero(), 0, 0, 1), f, a, b, c)
249249
```
250250

251251
Similarly, the most trivial use of `rrule` and returned `pullback` is to calculate the [Gradient](https://en.wikipedia.org/wiki/Gradient):
@@ -320,10 +320,10 @@ x = 3;
320320
ẋ = 1; # ∂x/∂x
321321
nofields = Zero(); # ∂self/∂self
322322
323-
a, ȧ = frule(sin, x, nofields, ẋ); # ∂a/∂x
324-
b, ḃ = frule(*, 2, a, nofields, Zero(), unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
323+
a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x
324+
b, ḃ = frule((nofields, Zero(), unthunk(ȧ)), *, 2, a); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
325325
326-
c, ċ = frule(asin, b, nofields, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
326+
c, ċ = frule((nofields, unthunk(ḃ)), asin, b); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
327327
unthunk(ċ)
328328
# output
329329
-2.0638950738662625

src/rule_definition_tools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
148148

149149
# Δs is the input to the propagator rule
150150
# because this is push-forward there is one per input to the function
151-
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs]
151+
Δs = [esc(Symbol(:Δ, i)) for i in 1:n_inputs]
152152
pushforward_returns = map(1:n_outputs) do output_i
153153
∂s = partials[output_i].args
154154
propagation_expr(Δs, ∂s)
@@ -163,7 +163,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
163163
return quote
164164
# _ is the input derivative w.r.t. function internals. since we do not
165165
# allow closures/functors with @scalar_rule, it is always ignored
166-
function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...))
166+
function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...))
167167
$(esc(:Ω)) = $call
168168
$(setup_stmts...)
169169
return $(esc(:Ω)), $pushforward_returns

src/rules.jl

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,54 +2,48 @@
22
##### `frule`/`rrule`
33
#####
44

5-
# TODO: remember to update the examples
65
"""
7-
frule(f, x..., ṡelf, Δx...)
6+
frule((Δf, Δx...), f, x...)
87
9-
Expressing `x` as the tuple `(x₁, x₂, ...)`, `Δx` as the tuple `(Δx₁, Δx₂,
10-
...)`, and the output tuple of `f(x...)` as `Ω`, return the tuple:
8+
Expressing the output of `f(x...)` as `Ω`, return the tuple:
119
12-
(Ω, (Ω̇₁, Ω̇₂, ...))
10+
(Ω, ΔΩ)
1311
14-
The second return value is the propagation rule, or the pushforward.
15-
It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`)
16-
and `ṡelf` the internal values of the function (for closures).
12+
The second return value is the differential w.r.t. the output.
1713
18-
19-
If no method matching `frule(f, x..., ṡelf, Δx...)` has been defined, then
20-
return `nothing`.
14+
If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`.
2115
2216
Examples:
2317
2418
unary input, unary output scalar function:
2519
26-
```
20+
```jldoctest
2721
julia> dself = Zero()
2822
Zero()
2923
3024
julia> x = rand();
3125
32-
julia> sinx, sin_pushforward = frule(sin, x, dself, 1)
26+
julia> sinx, Δsinx = frule((dself, 1), sin, x)
3327
(0.35696518021277485, 0.9341176907197836)
3428
3529
julia> sinx == sin(x)
3630
true
3731
38-
julia> sin_pushforward == cos(x)
32+
julia> Δsinx == cos(x)
3933
true
4034
```
4135
4236
unary input, binary output scalar function:
4337
44-
```
38+
```jldoctest
4539
julia> x = rand();
4640
47-
julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1);
41+
julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);
4842
4943
julia> sincosx == sincos(x)
5044
true
5145
52-
julia> sincos_pushforward == (cos(x), -sin(x))
46+
julia> Δsincosx == (cos(x), -sin(x))
5347
true
5448
```
5549

test/rules.jl

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,56 @@ nice(x) = 1
1515
very_nice(x, y) = x + y
1616
@scalar_rule(very_nice(x, y), (One(), One()))
1717

18+
19+
# Tests that aim to ensure that the API for frules doesn't regress and make these things
20+
# hard to implement.
21+
22+
varargs_function(x...) = sum(x)
23+
function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...)
24+
Δx = Base.tail(dargs)
25+
return sum(x), sum(Δx)
26+
end
27+
28+
mixed_vararg(x, y, z...) = x + y + sum(z)
29+
function ChainRulesCore.frule(
30+
dargs::Tuple{Any, Any, Any, Vararg},
31+
::typeof(mixed_vararg), x, y, z...,
32+
)
33+
Δx = dargs[2]
34+
Δy = dargs[3]
35+
Δz = dargs[4:end]
36+
return mixed_vararg(x, y, z...), Δx + Δy + sum(Δz)
37+
end
38+
39+
type_constraints(x::Int, y::Float64) = x + y
40+
function ChainRulesCore.frule(
41+
(_, Δx, Δy)::Tuple{Any, Int, Float64},
42+
::typeof(type_constraints), x::Int, y::Float64,
43+
)
44+
return type_constraints(x, y), Δx + Δy
45+
end
46+
47+
mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z)
48+
function ChainRulesCore.frule(
49+
dargs::Tuple{Any, Float64, Real, Vararg{Float64}},
50+
::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64},
51+
)
52+
Δx = dargs[2]
53+
Δy = dargs[3]
54+
Δz = dargs[4:end]
55+
return x + y + sum(z), Δx + Δy + sum(Δz)
56+
end
57+
58+
ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...)
59+
1860
#######
1961

2062
_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
2163

2264
@testset "frule and rrule" begin
2365
dself = Zero()
24-
@test frule(cool, 1, dself, 1) === nothing
25-
@test frule(cool, 1, dself, 1; iscool=true) === nothing
66+
@test frule((dself, 1), cool, 1) === nothing
67+
@test frule((dself, 1), cool, 1; iscool=true) === nothing
2668
@test rrule(cool, 1) === nothing
2769
@test rrule(cool, 1; iscool=true) === nothing
2870

@@ -37,7 +79,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
3779
Tuple{typeof(rrule),typeof(cool),String}])
3880
@test cool_methods == only_methods
3981

40-
frx, cool_pushforward = frule(cool, 1, dself, 1)
82+
frx, cool_pushforward = frule((dself, 1), cool, 1)
4183
@test frx === 2
4284
@test cool_pushforward === 1
4385
rrx, cool_pullback = rrule(cool, 1)
@@ -46,13 +88,38 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
4688
@test rrx === 2
4789
@test rr1 === 1
4890

49-
frx, nice_pushforward = frule(nice, 1, dself, 1)
91+
frx, nice_pushforward = frule((dself, 1), nice, 1)
5092
@test nice_pushforward === Zero()
5193
rrx, nice_pullback = rrule(nice, 1)
5294
@test (NO_FIELDS, Zero()) === nice_pullback(1)
5395

54-
sx = @SVector [1, 2]
55-
sy = @SVector [3, 4]
56-
# This is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting
57-
@inferred frule(very_nice, 1, 2, Zero(), sx, sy)
96+
97+
# Test that these run. Do not care about numerical correctness.
98+
@test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0)
99+
100+
@test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0)
101+
102+
@test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0)
103+
@test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing
104+
105+
@test(frule(
106+
(nothing, 3.0, 2.0, 1.0, 0.0),
107+
mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0,
108+
) == (6.0, 6.0))
109+
110+
# violates type constraints, thus an frule should not be found.
111+
@test frule(
112+
(nothing, 3, 2.0, 1.0, 5.0),
113+
mixed_vararg_type_constaint, 3, 2.0, 1.0, 0,
114+
) == nothing
115+
116+
@test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0)
117+
118+
@testset "broadcasting One" begin
119+
sx = @SVector [1, 2]
120+
sy = @SVector [3, 4]
121+
122+
# Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting
123+
@inferred frule((Zero(), sx, sy), very_nice, 1, 2)
124+
end
58125
end

0 commit comments

Comments
 (0)