Skip to content

Commit 605354c

Browse files
authored
Add rrules for extrema, findmax, maximum (#480)
* rules for extrema, findmax, maximum * fixup extrema * symmetric maximum rule * promote types by hand * argmax? * allow more zeros * upgrade tests * don't do symmetric convention * tests * fix 1.0 * rm symmetric versions * move extrema to last * tidy * fixup extrema * tests * tests * use eval loop, tidy, tests * forward rules for maximum * frules for findmax * tidy * widen similar to ensure writeability * comments * dispatch -> branch * allow for second derivatives * frule? * update to use CRC 1.3 * better writezero? * fix tests * allow arrays of arrays * version
1 parent a751937 commit 605354c

File tree

4 files changed

+243
-3
lines changed

4 files changed

+243
-3
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.13.0"
3+
version = "1.14.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,10 +11,10 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212

1313
[compat]
14-
ChainRulesCore = "1.10"
14+
ChainRulesCore = "1.11"
1515
ChainRulesTestUtils = "1"
1616
Compat = "3.35"
17-
FiniteDifferences = "0.12.8"
17+
FiniteDifferences = "0.12.20"
1818
JuliaInterpreter = "0.8"
1919
RealDot = "0.1"
2020
StaticArrays = "1.2"

src/rulesets/Base/array.jl

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,165 @@ function rrule(::typeof(fill), x::Any, dims...)
342342
fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...)
343343
return fill(x, dims...), fill_pullback
344344
end
345+
346+
#####
347+
##### `findmax`, `maximum`, etc.
348+
#####
349+
350+
# Generate matching forward and reverse rules for `findmin` and `findmax`.
# Both functions return `(value, index)`; only `value` is differentiable.
for reducer in (:findmin, :findmax)
    pullback_name = Symbol(reducer, :_pullback)

    # Forward rule: the tangent of the value is the entry of `xdot` at the selected
    # index (or indices, when `dims` is given); the index itself has no tangent.
    @eval function frule((_, xdot), ::typeof($reducer), x; dims=:)
        val, idx = $reducer(x; dims=dims)
        primal = (val, idx)
        return primal, Tangent{typeof(primal)}(xdot[idx], NoTangent())
    end

    # Reverse rule: the cotangent of the value is written back at the selected index.
    @eval function rrule(::typeof($reducer), x::AbstractArray; dims=:)
        val, idx = $reducer(x; dims=dims)
        to_x = ProjectTo(x)
        # This pullback is a lot like the one for getindex; ideally the two would
        # probably be combined.
        # It accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing); the second
        # component (the index cotangent) is ignored.
        function $pullback_name((dval, _))
            if dval isa AbstractZero
                return (NoTangent(), NoTangent())
            end
            # Out-of-place version: a zero array like `x` with `dval` written at `idx`.
            lazy_dx = @thunk to_x(_zerolike_writeat(x, unthunk(dval), dims, idx))
            # In-place version: accumulate `dval` into an existing gradient array.
            accumulating_dx = InplaceableThunk(lazy_dx) do acc
                slot = view(acc, idx)
                if dims isa Colon
                    # `Ref` stops broadcasting from iterating inside `dval`.
                    slot .= slot .+ Ref(unthunk(dval))
                else
                    # This could be `.+=`, but not on Julia 1.0.
                    slot .= slot .+ unthunk(dval)
                end
                acc
            end
            return (NoTangent(), accumulating_dx)
        end
        return (val, idx), $pullback_name
    end
end
378+
379+
# This function is roughly `setindex!(zero(x), dy, inds...)`:
380+
381+
# Build an array of zeros with the axes of `x` and element type `eltype(dy)`, then
# write `dy` at the location(s) `inds...`. `dims` is unused by this method; it is
# part of the signature so that all methods share one calling convention.
function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...)
    # It's unfortunate to need `x` itself here, but `similar(typeof(x), axes(x))`
    # doesn't allow `eltype(dy)`, nor does it work for many structured matrices.
    T = eltype(dy)
    out = similar(x, T, axes(x))
    fill!(out, 0)
    # The view is possibly 0-dimensional; broadcasting into it allows `dy::Number`
    # and `dy::Array`, and an `out` that is a CuArray.
    target = view(out, inds...)
    target .= dy
    return out
end
388+
# Fallback for arrays whose elements need not be numbers: since `x` is available,
# arrays of arrays can be handled by zeroing every element individually.
function _zerolike_writeat(x::AbstractArray, dy, dims, inds...)
    out = map(zero, x)
    target = view(out, inds...)
    if dims isa Colon
        # `dy` is a single element of `out`; `Ref` keeps broadcasting from
        # iterating over its contents.
        target .= Ref(dy)
    else
        target .= dy
    end
    return out
end
398+
399+
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
400+
# these rules are the reason it takes a `dims` argument.
401+
402+
# Forward rule for `_zerolike_writeat`: the output tangent is built by applying the
# same function to the tangent of `dy` (the third entry of the tangent tuple).
function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
    primal = _zerolike_writeat(x, dy, dims, inds...)
    tangent = _zerolike_writeat(x, dydot, dims, inds...)
    return primal, tangent
end
405+
406+
# Reverse rule for `_zerolike_writeat`. Only `dy` receives a gradient: the entries of
# the output cotangent at `inds...`, summed over `dims`. Every other argument gets
# `NoTangent()`.
function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
    z = _zerolike_writeat(x, dy, dims, inds...)
    # One `NoTangent()` per index argument.
    index_tangents = map(_ -> NoTangent(), inds)
    function _zerolike_writeat_pullback(dz)
        picked = view(unthunk(dz), inds...)
        ddy = sum(picked; dims=dims)
        return (NoTangent(), NoTangent(), ddy, NoTangent(), index_tangents...)
    end
    return z, _zerolike_writeat_pullback
end
415+
416+
# These rules for `maximum` pick the same subgradient as `findmax`:
417+
418+
# Forward rule for `maximum`: locate the maximum with `findmax` and return the
# entry of `xdot` at that index (or those indices, when `dims` is given).
function frule((_, xdot), ::typeof(maximum), x; dims=:)
    best, where_best = findmax(x; dims=dims)
    return (best, getindex(xdot, where_best))
end
422+
423+
# Reverse rule for `maximum`, delegating to the `findmax` rule. The index half of the
# `findmax` output is dropped, and `nothing` is passed as its cotangent.
function rrule(::typeof(maximum), x::AbstractArray; dims=:)
    found, findmax_back = rrule(findmax, x; dims=dims)
    y = first(found)
    function maximum_pullback(dy)
        return findmax_back((dy, nothing))
    end
    return y, maximum_pullback
end
428+
429+
# Forward rule for `minimum`: locate the minimum with `findmin` and return the
# entry of `xdot` at that index (or those indices, when `dims` is given).
function frule((_, xdot), ::typeof(minimum), x; dims=:)
    best, where_best = findmin(x; dims=dims)
    return (best, getindex(xdot, where_best))
end
433+
434+
# Reverse rule for `minimum`, delegating to the `findmin` rule. The index half of the
# `findmin` output is dropped, and `nothing` is passed as its cotangent.
function rrule(::typeof(minimum), x::AbstractArray; dims=:)
    found, findmin_back = rrule(findmin, x; dims=dims)
    y = first(found)
    function minimum_pullback(dy)
        return findmin_back((dy, nothing))
    end
    return y, minimum_pullback
end
439+
440+
#####
441+
##### `extrema`
442+
#####
443+
444+
# Reverse rule for `extrema` on numeric arrays. The whole-array case (`dims = :`) and
# the per-slice case are implemented by separate helpers.
function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:)
    dims isa Colon && return _extrema_colon(x)
    return _extrema_dims(x, dims)
end
451+
452+
# `extrema(x)` over the whole array: returns `(min, max)` and a pullback which writes
# the two cotangents at the positions found by `findmin` and `findmax`.
function _extrema_colon(x)
    lo, lo_at = findmin(x)
    hi, hi_at = findmax(x)
    to_x = ProjectTo(x)
    # The pullback destructures its argument, so it accepts a Tangent.
    function extrema_pullback((dlo, dhi))
        if dlo isa AbstractZero && dhi isa AbstractZero
            return (NoTangent(), NoTangent())
        end
        # One argument may still be an AbstractZero here. Use promote_op because
        # promote_type allows for * as well as +, hence gives Any.
        T = Base.promote_op(+, typeof(dlo), typeof(dhi))
        # The gradient is built eagerly: the original author noted that wrapping this
        # in `@thunk` does not infer.
        grad = fill!(similar(x, T, axes(x)), false)
        view(grad, lo_at) .= dlo
        # Add rather than overwrite, since both extrema can sit at the same index.
        view(grad, hi_at) .= view(grad, hi_at) .+ dhi
        return (NoTangent(), to_x(grad))
    end
    return (lo, hi), extrema_pullback
end
479+
480+
# `extrema(x; dims)`: returns an array of `(min, max)` tuples, one per slice, and a
# pullback which writes each pair of cotangents at the positions found by `findmin`
# and `findmax`.
function _extrema_dims(x, dims)
    ylo, ilo = findmin(x; dims=dims)
    yhi, ihi = findmax(x; dims=dims)
    y = similar(ylo, Tuple{eltype(ylo), eltype(yhi)})
    map!(tuple, y, ylo, yhi)  # this is a GPU-friendly version of collect(zip(ylo, yhi))
    project = ProjectTo(x)
    function extrema_pullback_dims(dy_raw)
        dy = unthunk(dy_raw)
        # A zero cotangent has no tuple elements to split, so return early, as the
        # pullbacks for `findmin`/`findmax` and for `extrema` without `dims` do.
        dy isa AbstractZero && return (NoTangent(), NoTangent())
        @assert dy isa AbstractArray{<:Tuple{Any,Any}}
        # Can we actually get Array{Tuple{Float64,ZeroTangent}} here? Not sure.
        T = Base.promote_op(+, eltype(dy).parameters...)
        # The gradient is built eagerly: the original author noted that wrapping this
        # in `@thunk` does not infer.
        dx = fill!(similar(x, T, axes(x)), false)
        view(dx, ilo) .= first.(dy)
        # Add rather than overwrite, since both extrema of a slice can sit at the
        # same index.
        view(dx, ihi) .= view(dx, ihi) .+ last.(dy)
        return (NoTangent(), project(dx))
    end
    return y, extrema_pullback_dims
end

src/rulesets/Base/nondiff.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@
9696
@non_differentiable all(::Any, ::Any)
9797
@non_differentiable any(::Any)
9898
@non_differentiable any(::Any, ::Any)
99+
@non_differentiable argmax(::Any)
100+
@non_differentiable argmin(::Any)
99101
@non_differentiable ascii(::AbstractString)
100102
@non_differentiable axes(::Any)
101103
@non_differentiable axes(::Any, ::Any)

test/rulesets/Base/array.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,79 @@ end
198198
test_rrule(fill, 55 + 0.5im, 5)
199199
test_rrule(fill, 3.3, (3, 3, 3))
200200
end
201+
202+
@testset "findmin & findmax" begin
    # Forward
    test_frule(findmin, rand(10))
    test_frule(findmax, rand(10))
    @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4))) isa Tuple{Tuple{Float64, CartesianIndex}, Tangent}
    @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4), dims=1)) isa Tuple{Tuple{Matrix, Matrix}, Tangent}
    @test_skip test_frule(findmin, rand(3,4)) # error from test_approx(actual::CartesianIndex{2}, expected::CartesianIndex{2})
    @test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent()))
    @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,))
    # These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188

    # Reverse
    test_rrule(findmin, rand(10), output_tangent = (rand(), false))
    test_rrule(findmax, rand(10), output_tangent = (rand(), false))
    test_rrule(findmin, rand(5,3))
    test_rrule(findmax, rand(5,3))
    @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
    @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])

    # Reverse with dims:
    @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2])
    @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2])
    test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()))
    test_rrule(findmin, rand(3,4), fkwargs=(dims=2,))

    # Second derivatives
    test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
    test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
    # `arg ⊢ tangent` gives the index argument an explicit tangent; without the operator
    # the two expressions sit side by side in the argument list and do not parse.
    @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent())  # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9)
    y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)])
    @test y == [0 0; 5 5]
    @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent())
end
235+
236+
@testset "$imum" for imum in [maximum, minimum]
    along_first = (dims=1,)

    # Forward rules: vectors, matrices, and arrays of arrays
    test_frule(imum, rand(10))
    test_frule(imum, rand(3,4))
    test_frule(imum, rand(3,4), fkwargs=along_first)
    test_frule(imum, [rand(2) for _ in 1:3])
    test_frule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=along_first)

    # Reverse rules: vectors, matrices, and reductions over one or several dims
    test_rrule(imum, rand(10))
    test_rrule(imum, rand(3,4))
    test_rrule(imum, rand(3,4), fkwargs=along_first)
    test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),))

    # Reverse rules on arrays of arrays (inference is not checked here)
    test_rrule(imum, [rand(2) for _ in 1:3]; check_inferred=false)
    test_rrule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=along_first, check_inferred=false)

    # Case which attains the extremum more than once -- can't use FiniteDifferences
    # for this. The expected gradient is nonzero only at the first such position.
    res = if imum == maximum
        [0,1,0,0,0,0]
    else
        [1,0,0,0,0,0]
    end
    @test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2])

    # Structured matrix -- NB the minimum is a structural zero here
    @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal
    @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular{Float64}
    @test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool)
end
263+
264+
@testset "extrema" begin
    # Whole-array and per-slice reverse rules
    test_rrule(extrema, rand(10), output_tangent = (rand(), rand()))
    test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = collect(zip(rand(1,4), rand(1,4))))

    # Cases where both extrema are at the same index, to check accumulation:
    test_rrule(extrema, rand(1), output_tangent = (rand(), rand()))
    test_rrule(extrema, rand(1,1), fkwargs=(dims=2,), output_tangent = hcat((rand(), rand())))
    test_rrule(extrema, rand(3,1), fkwargs=(dims=2,), output_tangent = collect(zip(rand(3,1), rand(3,1))))

    # Double-check the forward pass: the rule's primal output must equal `extrema`
    A = randn(3,4,5)
    @test extrema(A, dims=(1,3)) == rrule(extrema, A, dims=(1,3))[1]
    front = A[:,:,1]
    B = hcat(front, front)  # every row holds each value twice
    @test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1]
end

0 commit comments

Comments
 (0)