Skip to content

Commit c8d01a3

Browse files
Overhaul Rules (partner PR) (#91)
* [WIP] include dervative WRT self. Scalar functions changed over * wip * WIP * [WIP] make changes to all the rules to return WRT self Temp comment out all accumulate related tests * comment out more accumulate * WIP: * all real scalar rules working * Wirtinger scalars passing * all tests in tests/rulesets/Base/base.jl passing * Fixup Base tests to match frule not returning a tuple * attay test passing * Broadcast fixed * WIP fixing up mapreduce file * make structured and dense rulesets pass * BLAS written but need to re-sort out update rules before done proper * BLAS rules working but update accumulation inplace is diabled * Factorizations working * Make statistics work * remove double extern * fix bad rebase * WIP use InplaceableThunks for updating rules * make factorizations accumulate! right * style and typos Co-Authored-By: Nick Robinson <npr251@gmail.com> * use _fdm rather than making a new central_fdm Co-Authored-By: Nick Robinson <npr251@gmail.com> * set version correctly Co-Authored-By: Nick Robinson <npr251@gmail.com> * name some pullbacks * name more propagators * More named propagators * Name more propagators * More named propagators * delete extra unused _update! methods * name more progators * fix up typos and extern new thunks in tests * test nonsquares * more named propagators * Apply suggestions from code review Co-Authored-By: Nick Robinson <npr251@gmail.com> * more named propagators * Update test/rulesets/Base/base.jl Co-Authored-By: Nick Robinson <npr251@gmail.com>
1 parent d8339cc commit c8d01a3

File tree

21 files changed

+794
-350
lines changed

21 files changed

+794
-350
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.1.1"
3+
version = "0.2.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
ChainRulesCore = "^0.2"
13+
ChainRulesCore = "^0.3"
1414
FiniteDifferences = "^0.7"
1515
julia = "^1.0"
1616

src/helper_functions.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
# Special purpose updating for operations which can be done in-place. This function is
2-
# just internal and free-form; it is not a method of `accumulate!` directly as it does
3-
# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`.
4-
# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular
5-
# rule.
1+
# Internal helpers for defining the `add!` field of an `InplaceableThunk`
62

73
_update!(x, y) = x + y
84
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y
@@ -11,20 +7,22 @@ _update!(x, ::Zero) = x
117
_update!(::Zero, y) = y
128
_update!(::Zero, ::Zero) = Zero()
139

14-
function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns
15-
return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns))
16-
end
1710

1811
function _update!(x::NamedTuple, y, p::Symbol)
19-
new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),))
12+
y = extern(y)
13+
yp = getproperty(y, p)
14+
xp = getproperty(x, p)
15+
new_xp = _update!(xp, yp)
16+
new = NamedTuple{(p,)}((new_xp,))
2017
return merge(x, new)
2118
end
2219

23-
function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns
24-
return _update!(x, getproperty(y, p), p)
25-
end
26-
20+
"""
21+
_checked_rrule
2722
23+
like `rrule` but throws an error if the `rrule` is not defined.
24+
Rather than returning `nothing`
25+
"""
2826
function _checked_rrule(f, args...; kwargs...)
2927
r = rrule(f, args...; kwargs...)
3028
r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...)

src/rulesets/Base/array.jl

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,56 +3,74 @@
33
#####
44

55
function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}})
6-
return reshape(A, dims), (Rule(Ȳ->reshape(Ȳ, dims)), DNERule())
6+
function reshape_pullback(Ȳ)
7+
return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DNE())
8+
end
9+
return reshape(A, dims), reshape_pullback
710
end
811

912
function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
10-
Y, (rule, _) = rrule(reshape, A, dims)
11-
return Y, (rule, fill(DNERule(), length(dims))...)
13+
function reshape_pullback(Ȳ)
14+
∂A = @thunk(reshape(Ȳ, dims))
15+
return (NO_FIELDS, ∂A, fill(DNE(), length(dims))...)
16+
end
17+
return reshape(A, dims...), reshape_pullback
1218
end
1319

1420
#####
1521
##### `hcat` (🐈)
1622
#####
1723

1824
function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...)
19-
Y = hcat(A, Bs...)
20-
Xs = (A, Bs...)
21-
rules = ntuple(length(Bs) + 1) do i
22-
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
23-
u = l + size(Xs[i], 2)
24-
dim = u > l + 1 ? (l+1:u) : u
25-
# NOTE: The copy here is defensive, since `selectdim` returns a view which we can
26-
# materialize with `copy`
27-
Rule(Ȳ->copy(selectdim(Ȳ, 2, dim)))
25+
function hcat_pullback(Ȳ)
26+
Xs = (A, Bs...)
27+
ntuple(length(Bs) + 2) do full_i
28+
full_i == 1 && return NO_FIELDS
29+
30+
i = full_i - 1
31+
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
32+
u = l + size(Xs[i], 2)
33+
dim = u > l + 1 ? (l+1:u) : u
34+
# NOTE: The copy here is defensive, since `selectdim` returns a view which we can
35+
# materialize with `copy`
36+
copy(selectdim(Ȳ, 2, dim))
37+
end
2838
end
29-
return Y, rules
39+
return hcat(A, Bs...), hcat_pullback
3040
end
3141

3242
#####
3343
##### `vcat`
3444
#####
3545

3646
function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...)
37-
Y = vcat(A, Bs...)
38-
n = size(A, 1)
39-
∂A = Rule(Ȳ->copy(selectdim(Ȳ, 1, 1:n)))
40-
∂Bs = ntuple(length(Bs)) do i
41-
l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0)
42-
u = l + size(Bs[i], 1)
43-
Rule(Ȳ->copy(selectdim(Ȳ, 1, l+1:u)))
47+
function vcat_pullback(Ȳ)
48+
n = size(A, 1)
49+
∂A = copy(selectdim(Ȳ, 1, 1:n))
50+
∂Bs = ntuple(length(Bs)) do i
51+
l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0)
52+
u = l + size(Bs[i], 1)
53+
copy(selectdim(Ȳ, 1, l+1:u))
54+
end
55+
return (NO_FIELDS, ∂A, ∂Bs...)
4456
end
45-
return Y, (∂A, Bs...)
57+
return vcat(A, Bs...), vcat_pullback
4658
end
4759

4860
#####
4961
##### `fill`
5062
#####
5163

5264
function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}})
53-
return fill(value, dims), (Rule(sum), DNERule())
65+
function fill_pullback(Ȳ)
66+
return (NO_FIELDS, @thunk(sum(Ȳ)), DNE())
67+
end
68+
return fill(value, dims), fill_pullback
5469
end
5570

5671
function rrule(::typeof(fill), value::Any, dims::Int...)
57-
return fill(value, dims), (Rule(sum), ntuple(_->DNERule(), length(dims))...)
72+
function fill_pullback(Ȳ)
73+
return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DNE(), length(dims))...)
74+
end
75+
return fill(value, dims), fill_pullback
5876
end

src/rulesets/Base/base.jl

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,30 @@
103103

104104
# product rule requires special care for arguments where `mul` is non-commutative
105105

106-
frule(::typeof(*), x::Number, y::Number) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)
107-
108-
rrule(::typeof(*), x::Number, y::Number) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))
109-
110-
frule(::typeof(identity), x) = x, Rule(identity)
111-
112-
rrule(::typeof(identity), x) = x, Rule(identity)
106+
function frule(::typeof(*), x::Number, y::Number)
107+
function times_pushforward(_, Δx, Δy)
108+
return Δx * y + x * Δy
109+
end
110+
return x * y, times_pushforward
111+
end
112+
113+
function rrule(::typeof(*), x::Number, y::Number)
114+
function times_pullback(ΔΩ)
115+
return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
116+
end
117+
return x * y, times_pullback
118+
end
119+
120+
function frule(::typeof(identity), x)
121+
function identity_pushforward(_, ẏ)
122+
return
123+
end
124+
return x, identity_pushforward
125+
end
126+
127+
function rrule(::typeof(identity), x)
128+
function identity_pullback(ȳ)
129+
return (NO_FIELDS, ȳ)
130+
end
131+
return x, identity_pullback
132+
end

src/rulesets/Base/broadcast.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,26 @@ without relying on inference hacks unless we have something akin to
55
https://github.com/JuliaLang/julia/issues/22129.
66
=#
77
function _cast_diff(f, x)
8-
element_rule = u -> begin
8+
function element_rule(u)
99
fu, du = frule(f, u)
10-
fu, extern(du(One()))
10+
fu, extern(du(NamedTuple(), One()))
1111
end
1212
results = broadcast(element_rule, x)
1313
return first.(results), last.(results)
1414
end
1515

1616
function frule(::typeof(broadcast), f, x)
1717
Ω, ∂x = _cast_diff(f, x)
18-
return Ω, Rule((_, Δx) -> Δx * cast(∂x))
18+
function broadcast_pushforward(_, Δf, Δx)
19+
return Δx * cast(∂x)
20+
end
21+
return Ω, broadcast_pushforward
1922
end
2023

2124
function rrule(::typeof(broadcast), f, x)
2225
values, derivs = _cast_diff(f, x)
23-
return values, (DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs)))
26+
function broadcast_pullback(ΔΩ)
27+
return (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs)))
28+
end
29+
return values, broadcast_pullback
2430
end

src/rulesets/Base/mapreduce.jl

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@
44

55
function rrule(::typeof(map), f, xs...)
66
y = map(f, xs...)
7-
∂xs = ntuple(length(xs)) do i
8-
Rule() do
9-
map(ȳ, xs...) do ȳi, xis...
10-
_, ∂xis = _checked_rrule(f, xis...)
11-
extern(∂xis[i](ȳi))
7+
function map_pullback(ȳ)
8+
ntuple(length(xs)+2) do full_i
9+
full_i == 1 && return NO_FIELDS
10+
full_i == 2 && return DNE()
11+
i = full_i-2
12+
@thunk map(ȳ, xs...) do ȳi, xis...
13+
_, pullback = _checked_rrule(f, xis...)
14+
∂xis = pullback(ȳi)
15+
extern(∂xis[i+1]) #+1 to skp ∂self
1216
end
1317
end
1418
end
15-
return y, (DNERule(), ∂xs...)
19+
return y, map_pullback
1620
end
1721

1822
#####
@@ -26,15 +30,18 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr)
2630
insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:))))
2731
insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims)))
2832
end
33+
pullback_name = Symbol(mf, :_pullback)
2934
body = quote
3035
y = $call
31-
∂x = Rule() do
32-
broadcast(x, ȳ) do xi, ȳi
33-
_, ∂xi = _checked_rrule(f, xi)
34-
extern(∂xi(ȳi))
36+
function $pullback_name(ȳ)
37+
∂x = @thunk broadcast(x, ȳ) do xi, ȳi
38+
_, pullback_f = _checked_rrule(f, xi)
39+
_, ∂xi = pullback_f(ȳi)
40+
extern(∂xi)
3541
end
42+
(NO_FIELDS, DNE(), DNE(), ∂x)
3643
end
37-
return y, (DNERule(), DNERule(), ∂x)
44+
return y, $pullback_name
3845
end
3946
eval(Expr(:function, sig, body))
4047
end
@@ -43,22 +50,40 @@ end
4350
##### `sum`
4451
#####
4552

46-
frule(::typeof(sum), x) = (sum(x), Rule(sum))
53+
function frule(::typeof(sum), x)
54+
function sum_pushforward(_, ẋ)
55+
return sum(ẋ)
56+
end
57+
return sum(x), sum_pushforward
58+
end
4759

48-
rrule(::typeof(sum), x) = (sum(x), Rule(cast))
60+
function rrule(::typeof(sum), x)
61+
function sum_pullback(ȳ)
62+
return (NO_FIELDS, cast(ȳ))
63+
end
64+
return sum(x), sum_pullback
65+
end
4966

5067
function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
51-
y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
52-
return y, (DNERule(), ∂x)
68+
y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
69+
function sum_pullback(ȳ)
70+
NO_FIELDS, DNE(), last(mr_pullback(ȳ))
71+
end
72+
return y, sum_pullback
5373
end
5474

5575
function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
56-
y, (_, ∂x) = rrule(sum, identity, x; dims=dims)
57-
return y, ∂x
76+
y, inner_pullback = rrule(sum, identity, x; dims=dims)
77+
function sum_pullback(ȳ)
78+
NO_FIELDS, last(inner_pullback(ȳ))
79+
end
80+
return y, sum_pullback
5881
end
5982

6083
function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
6184
y = sum(abs2, x; dims=dims)
62-
∂x = Rule(ȳ -> 2.* x)
63-
return y, (DNERule(), ∂x)
85+
function sum_abs2_pullback(ȳ)
86+
return (NO_FIELDS, DNE(), @thunk(2.* x))
87+
end
88+
return y, sum_abs2_pullback
6489
end

0 commit comments

Comments
 (0)