Skip to content

Commit 7704d22

Browse files
authored
Merge pull request #88 from JuliaDiff/myb/fuse_frule
Fuse `frule` and pushforward
2 parents 0637def + 40467af commit 7704d22

File tree

7 files changed

+25
-27
lines changed

7 files changed

+25
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.5.0-DEV"
3+
version = "0.5.0"
44

55
[compat]
66
julia = "^1.0"

src/differentials/composite.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ function backing(x::T)::NamedTuple where T
8686
nfields = fieldcount(T)
8787
names = fieldnames(T)
8888
types = fieldtypes(T)
89-
89+
9090
vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...)
9191
return :(NamedTuple{$names, Tuple{$(types...)}}($vals))
9292
else

src/differentials/one.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@ Base.Broadcast.broadcastable(::One) = Ref(One())
1111

1212
Base.iterate(x::One) = (x, nothing)
1313
Base.iterate(::One, ::Any) = nothing
14-

src/differentials/zero.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,3 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
1111

1212
Base.iterate(x::Zero) = (x, nothing)
1313
Base.iterate(::Zero, ::Any) = nothing
14-
15-

src/rule_definition_tools.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,13 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
159159
pushforward_returns = pushforward_returns[1]
160160
end
161161

162-
pushforward = quote
162+
return quote
163163
# _ is the input derivative w.r.t. function internals. since we do not
164164
# allow closures/functors with @scalar_rule, it is always ignored
165-
function $(propagator_name(f, :pushforward))(_, $(Δs...))
166-
$pushforward_returns
167-
end
168-
end
169-
170-
return quote
171-
function ChainRulesCore.frule(::typeof($f), $(inputs...))
165+
function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...))
172166
$(esc()) = $call
173167
$(setup_stmts...)
174-
return $(esc()), $pushforward
168+
return $(esc()), $pushforward_returns
175169
end
176170
end
177171
end

src/rules.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,40 @@
22
##### `frule`/`rrule`
33
#####
44

5+
# TODO: remember to update the examples
56
"""
6-
frule(f, x...)
7+
frule(f, x..., ṡelf, Δx...)
78
8-
Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
9-
as `Ω`, return the tuple:
9+
Expressing `x` as the tuple `(x₁, x₂, ...)`, `Δx` as the tuple `(Δx₁, Δx₂,
10+
...)`, and the output tuple of `f(x...)` as `Ω`, return the tuple:
1011
11-
(Ω, (ṡelf, ẋ₁, ẋ₂, ...) -> Ω̇₁, Ω̇₂, ...)
12+
(Ω, (Ω̇₁, Ω̇₂, ...))
1213
1314
The second return value is the propagation rule, or the pushforward.
1415
It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`)
1516
and `ṡelf` the internal values of the function (for closures).
1617
1718
18-
If no method matching `frule(f, xs...)` has been defined, then return `nothing`.
19+
If no method matching `frule(f, x..., ṡelf, Δx...)` has been defined, then
20+
return `nothing`.
1921
2022
Examples:
2123
2224
unary input, unary output scalar function:
2325
2426
```
27+
julia> dself = Zero()
28+
Zero()
29+
2530
julia> x = rand();
2631
27-
julia> sinx, sin_pushforward = frule(sin, x);
32+
julia> sinx, sin_pushforward = frule(sin, x, dself, 1)
33+
(0.35696518021277485, 0.9341176907197836)
2834
2935
julia> sinx == sin(x)
3036
true
3137
32-
julia> sin_pushforward(NamedTuple(), 1) == cos(x)
38+
julia> sin_pushforward == cos(x)
3339
true
3440
```
3541
@@ -38,12 +44,12 @@ unary input, binary output scalar function:
3844
```
3945
julia> x = rand();
4046
41-
julia> sincosx, sincos_pushforward = frule(sincos, x);
47+
julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1);
4248
4349
julia> sincosx == sincos(x)
4450
true
4551
46-
julia> sincos_pushforward(NamedTuple(), 1) == (cos(x), -sin(x))
52+
julia> sincos_pushforward == (cos(x), -sin(x))
4753
true
4854
```
4955

test/rules.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ dummy_identity(x) = x
1313
_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
1414

1515
@testset "frule and rrule" begin
16-
@test frule(cool, 1) === nothing
17-
@test frule(cool, 1; iscool=true) === nothing
16+
dself = Zero()
17+
@test frule(cool, 1, dself, 1) === nothing
18+
@test frule(cool, 1, dself, 1; iscool=true) === nothing
1819
@test rrule(cool, 1) === nothing
1920
@test rrule(cool, 1; iscool=true) === nothing
2021

@@ -29,9 +30,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
2930
Tuple{typeof(rrule),typeof(cool),String}])
3031
@test cool_methods == only_methods
3132

32-
frx, cool_pushforward = frule(cool, 1)
33+
frx, cool_pushforward = frule(cool, 1, dself, 1)
3334
@test frx == 2
34-
@test cool_pushforward(NamedTuple(), 1) == 1
35+
@test cool_pushforward == 1
3536
rrx, cool_pullback = rrule(cool, 1)
3637
self, rr1 = cool_pullback(1)
3738
@test self == NO_FIELDS

0 commit comments

Comments
 (0)