Skip to content

Commit a751937

Browse files
authored
Rules for foldl and accumulate (#526)
* rrule for foldl + tests * rrule for accumulate + tests * rrule for cumsum + tests * rule for sum(::Tuple) * tests + tweaks * rm cumsum * comments * rm comments + old tests * test fixes * skip tuples on 1.0 * version bump * two suggestions, no more pi * tidying * updates to use Tuple ProjectTo, comments, tidying * more * fixes * one more * fixup for 1.0 * fix 1.0, comment * fix 1.6 too? * one more
1 parent edf3a1f commit a751937

File tree

4 files changed

+361
-14
lines changed

4 files changed

+361
-14
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.12.1"
3+
version = "1.13.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212

1313
[compat]
14-
ChainRulesCore = "1.1"
14+
ChainRulesCore = "1.10"
1515
ChainRulesTestUtils = "1"
1616
Compat = "3.35"
1717
FiniteDifferences = "0.12.8"

src/rulesets/Base/mapreduce.jl

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
##### `sum(x)`
33
#####
44

5+
function frule((_, ẋ), ::typeof(sum), x::Tuple)
6+
return sum(x), sum(ẋ)
7+
end
58
function frule((_, ẋ), ::typeof(sum), x; dims=:)
69
return sum(x; dims=dims), sum(ẋ; dims=dims)
710
end
@@ -324,3 +327,152 @@ end
324327
end
325328
return dx
326329
end
330+
331+
#####
332+
##### `foldl`
333+
#####
334+
335+
# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
336+
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.
337+
338+
# The implementation aims to be efficient for both tuples and arrays, although using accumulate
339+
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
340+
# loop can be a few times faster. Note also that it does not return a gradient for `init`.
341+
342+
function rrule(
343+
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple};
344+
init=_InitialValue()
345+
) where {G}
346+
list, start = if init === _InitialValue()
347+
_drop1(x), first(x)
348+
else
349+
# Case with init keyword is simpler to understand first!
350+
_reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy)
351+
end
352+
hobbits = accumulate(list; init=(start, nothing)) do (a,_), b
353+
# Here `a` is what we would normally cary forward, and `_` ignores
354+
# the previous iteration's pullback function (needed later),
355+
# while `b` is the fresh input from `list` as usual.
356+
c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here!
357+
# We don't really need to store every `c`, last one is `foldl` output.
358+
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
359+
end
360+
y = first(last(hobbits))
361+
axe = axes(x)
362+
project = ProjectTo(x)
363+
function unfoldl(dy)
364+
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
365+
ds, da, db = back(dc)
366+
# Don't need to store every `da`, need one for the next iteration + maybe last
367+
end
368+
dop = sum(first, trio)
369+
dx = map(last, _reverse1(trio))
370+
if init === _InitialValue()
371+
# `hobbits` is one short
372+
dx = _vcat1(trio[end][2], dx)
373+
end
374+
return (NoTangent(), dop, project(_reshape1(dx, axe)))
375+
end
376+
return y, unfoldl
377+
end
378+
379+
380+
#####
381+
##### Iterator-or-Tuple functions
382+
#####
383+
384+
# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
385+
# and also provides some alternatives for versions of Julia where iterators weren't supported.
386+
# Inspired by `Base._reverse`, used in defn of `foldr`.
387+
388+
# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
389+
# be replaced by _peel1 like Iterators.peel
390+
391+
if VERSION >= v"1.6"
392+
_reverse1(x) = Iterators.reverse(x)
393+
_drop1(x) = Iterators.drop(x, 1)
394+
_zip2(x, y) = zip(x, y) # for `accumulate`, below
395+
else
396+
# Old versions don't support accumulate(::itr), nor multi-dim reverse
397+
_reverse1(x) = reverse(vec(x))
398+
_drop1(x) = vec(x)[2:end]
399+
_zip2(x, y) = collect(zip(x, y))
400+
end
401+
_reverse1(x::Tuple) = reverse(x)
402+
_drop1(x::Tuple) = Base.tail(x)
403+
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)
404+
405+
struct _InitialValue end # Old versions don't have `Base._InitialValue`
406+
407+
_vcat1(x, ys::AbstractVector) = vcat(x, ys)
408+
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
409+
_vcat1(x, ys::Tuple) = (x, ys...)
410+
411+
_reshape1(x::AbstractArray, axe) = reshape(x, axe)
412+
_reshape1(x::Tuple, axe) = x
413+
414+
_no_tuple_tangent(dx::Tangent) = ChainRulesCore.backing(dx)
415+
_no_tuple_tangent(dx) = dx
416+
417+
418+
#####
419+
##### `accumulate`
420+
#####
421+
422+
# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.
423+
424+
function rrule(
425+
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple};
426+
init=_InitialValue(), dims=nothing
427+
) where {G}
428+
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw(
429+
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry"
430+
# It's not supported by AD either, so no point calling back, and no regression:
431+
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4))
432+
# ERROR: Mutating arrays is not supported
433+
)
434+
list, start = if init === _InitialValue()
435+
_drop1(x), first(x)
436+
else
437+
x, init
438+
end
439+
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
440+
c, back = rrule_via_ad(config, op, a, b)
441+
end
442+
y = map(first, hobbits)
443+
if init === _InitialValue()
444+
# `hobbits` is one short, and first one doesn't invoke `op`
445+
y = _vcat1(first(x), y)
446+
end
447+
axe = axes(x)
448+
project = ProjectTo(x)
449+
function decumulate(dy)
450+
dy_plain = _no_tuple_tangent(unthunk(dy))
451+
rev_list = if init === _InitialValue()
452+
if VERSION >= v"1.6"
453+
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
454+
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
455+
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
456+
else
457+
# However, on 1.0 and some others, zip does not stop early. But since accumulate
458+
# also doesn't work on iterators, `_drop1` doesn't make one, so this should work:
459+
_zip2(_reverse1(hobbits), _reverse1(_drop1(dy_plain)))
460+
# What an awful tangle.
461+
end
462+
else
463+
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
464+
end
465+
trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz)
466+
ds, da, db = back(dc + dz)
467+
# Don't need to store every 'da', but need for next iteration, and the last one.
468+
end
469+
dop = sum(first, trio)
470+
dx = map(last, _reverse1(trio))
471+
if init == _InitialValue()
472+
# `hobbits` is one short, and the first one is weird
473+
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
474+
end
475+
return (NoTangent(), dop, project(_reshape1(dx, axe)))
476+
end
477+
return _reshape1(y, axe), decumulate
478+
end

test/rulesets/Base/mapreduce.jl

Lines changed: 125 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)
33
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
44

5-
@testset "Maps and Reductions" begin
5+
const CFG = ChainRulesTestUtils.ADviaRuleConfig()
6+
7+
@testset "Reductions" begin
8+
@testset "sum(::Tuple)" begin
9+
test_frule(sum, Tuple(rand(5)))
10+
end
611
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
712
# Forward
813
test_frule(sum, rand(5); fkwargs=(;dims=dims))
@@ -79,12 +84,11 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
7984
test_rrule(sum, inv, transpose(view(x, 1, :)))
8085

8186
# Make sure we preserve type for StaticArrays
82-
ADviaRuleConfig = ChainRulesTestUtils.ADviaRuleConfig
83-
_, pb = rrule(ADviaRuleConfig(), sum, abs, @SVector[1.0, -3.0])
87+
_, pb = rrule(CFG, sum, abs, @SVector[1.0, -3.0])
8488
@test pb(1.0) isa Tuple{NoTangent, NoTangent, SVector{2, Float64}}
8589

8690
# make sure we preserve type for Diagonal
87-
_, pb = rrule(ADviaRuleConfig(), sum, abs, Diagonal([1.0, -3.0]))
91+
_, pb = rrule(CFG, sum, abs, Diagonal([1.0, -3.0]))
8892
@test pb(1.0)[3] isa Diagonal
8993

9094
# Boolean -- via @non_differentiable, test that this isn't ambiguous
@@ -173,7 +177,64 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
173177
@test unthunk(rrule(prod, v)[2](1f0)[2]) == zeros(4)
174178
test_rrule(prod, v)
175179
end
176-
end # prod
180+
end # prod
181+
182+
@testset "foldl(f, ::Array)" begin
183+
# Simple
184+
y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
185+
@test y1 == 6
186+
b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])
187+
188+
y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4]) # without init, needs vcat
189+
@test y2 == 0
190+
b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0]) # matrix, needs reshape
191+
192+
# Test execution order
193+
c5 = Counter()
194+
y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
195+
@test c5 == Counter(2)
196+
@test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), [5, 7, 11])
197+
@test b5(1) == (NoTangent(), NoTangent(), [12*32, 12*42, 22])
198+
@test c5 == Counter(42)
199+
200+
c6 = Counter()
201+
y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3)
202+
@test c6 == Counter(3)
203+
@test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), [5, 7, 11], init=3)
204+
@test b6(1) == (NoTangent(), NoTangent(), [63*33*13, 43*13, 23])
205+
@test c6 == Counter(63)
206+
207+
# Test gradient of function
208+
y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
209+
@test y7 == foldl((x,y)->x*y*3, [5, 7, 11])
210+
@test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315])
211+
212+
y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3)
213+
@test y8 == 2_537_535 == foldl((x,y)->x*y*13, [5, 7, 11], init=3)
214+
@test b8(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685])
215+
# To find these numbers:
216+
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
217+
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
218+
219+
# Finite differencing
220+
test_rrule(foldl, /, 1 .+ rand(3,4))
221+
test_rrule(foldl, *, rand(ComplexF64,3,4); fkwargs=(; init=rand(ComplexF64)))
222+
test_rrule(foldl, +, rand(ComplexF64,7); fkwargs=(; init=rand(ComplexF64)))
223+
test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
224+
end
225+
VERSION >= v"1.5" && @testset "foldl(f, ::Tuple)" begin
226+
y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
227+
@test y1 == 6
228+
b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))
229+
230+
y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
231+
@test y2 == 0
232+
b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))
233+
234+
# Finite differencing
235+
test_rrule(foldl, /, Tuple(1 .+ rand(5)))
236+
test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
237+
end
177238
end
178239

179240
@testset "Accumulations" begin
@@ -188,14 +249,14 @@ end
188249
@testset "higher dimensions, dims=$dims" for dims in (1,2,3)
189250
m = round.(10 .* randn(4,5), sigdigits=3)
190251
test_rrule(cumprod, m; fkwargs=(;dims=dims), atol=0.1)
191-
m[2,2] = 0
192-
m[2,4] = 0
252+
m[2, 2] = 0
253+
m[2, 4] = 0
193254
test_rrule(cumprod, m; fkwargs=(;dims=dims))
194255

195256
t = round.(10 .* randn(3,3,3), sigdigits=3)
196257
test_rrule(cumprod, t; fkwargs=(;dims=dims))
197-
t[2,2,2] = 0
198-
t[2,3,3] = 0
258+
t[2, 2, 2] = 0
259+
t[2, 3, 3] = 0
199260
test_rrule(cumprod, t; fkwargs=(;dims=dims))
200261
end
201262

@@ -211,5 +272,60 @@ end
211272
back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2]
212273
@test unthunk(back(fill(0.5, 2, 2))[2]) [1/2 0; 0 0] # ProjectTo'd to Diagonal now
213274
end
275+
end # cumprod
276+
277+
@testset "accumulate(f, ::Array)" begin
278+
# Simple
279+
y1, b1 = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1)
280+
@test y1 == [1, 2, 6, 24]
281+
@test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6])
282+
283+
if VERSION >= v"1.5"
284+
y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
285+
@test y2 accumulate(/, [1 2; 3 4])
286+
@test b2(ones(2, 2))[3] [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
287+
end
288+
289+
# Test execution order
290+
c3 = Counter()
291+
y3, b3 = rrule(CFG, accumulate, c3, [5, 7, 11]; init=3)
292+
@test c3 == Counter(3)
293+
@test y3 == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3)
294+
@test b3([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23]) # the 23 is clear!
295+
296+
c4 = Counter()
297+
y4, b4 = rrule(CFG, accumulate, c4, [5, 7, 11])
298+
@test c4 == Counter(2)
299+
@test y4 == [5, (5+7)*1, ((5+7)*1 + 11)*2] == accumulate(Counter(), [5, 7, 11])
300+
@test b4([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42*(1 + 12), 22])
301+
302+
# Test gradient of function
303+
y7, b7 = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11])
304+
@test y7 == accumulate((x,y)->x*y*3, [5, 7, 11])
305+
@test b7([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2345,), [715, 510, 315])
306+
307+
y8, b8 = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11], init=3)
308+
@test y8 == [195, 17745, 2537535] == accumulate((x,y)->x*y*13, [5, 7, 11], init=3)
309+
@test b8([1, 1, 1]) == (NoTangent(), Tangent{Multiplier{Int}}(x = 588330,), [511095, 365040, 230685])
310+
# To find these numbers:
311+
# ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
312+
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
313+
314+
# Finite differencing
315+
test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
316+
if VERSION >= v"1.5"
317+
test_rrule(accumulate, /, 1 .+ rand(3, 4))
318+
test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
319+
end
320+
end
321+
VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
322+
# Simple
323+
y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
324+
@test y1 == (1, 2, 6, 24)
325+
@test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
326+
327+
# Finite differencing
328+
test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
329+
test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
214330
end
215331
end

0 commit comments

Comments
 (0)