Skip to content

Commit ffefa07

Browse files
mcabbott and oxinabox
authored
Sometimes faster sum(f,x) rule (#529)
* save less stuff in sum(f, xs) rule * version using derivatives_given_input * rules * rm derivatives_given_input * add and fix some tests * rm benchmarks * rebase fixup * fix tests * fix a test * tighter check * tidy, more unicode * rm one unthunk * comment * simplify AbstractZero methods * Apply 4 suggestions Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au> Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au>
1 parent c5dbe03 commit ffefa07

File tree

5 files changed

+94
-20
lines changed

5 files changed

+94
-20
lines changed

src/ChainRules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ using Statistics
1515
# to the normal rule of only overload via `ChainRulesCore.rrule`.
1616
import ChainRulesCore: rrule, frule
1717

18+
# Experimental:
19+
using ChainRulesCore: derivatives_given_output
20+
1821
# numbers that we know commute under multiplication
1922
const CommutativeMulNumber = Union{Real,Complex}
2023

src/rulesets/Base/base.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ function rrule(::typeof(identity), x)
171171
return (x, identity_pullback)
172172
end
173173

174+
# For `identity` the result is the constant `true`, wrapped in the nested-tuple
# form `((∂,),)`; neither `Ω` nor `x` is inspected.
ChainRulesCore.derivatives_given_output(Ω, ::typeof(identity), x) = ((true,),)
175+
174176
# rounding related,
175177
# we use `zero` rather than `ZeroTangent()` for scalar, and avoids issues with map etc
176178
@scalar_rule round(x) zero(x)

src/rulesets/Base/fastmath_able.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ let
8080
return Ω, abs_pullback
8181
end
8282

83+
# Derivative of `abs` at `x`, returned in the nested-tuple form `((∂,),)`.
# NOTE(review): `Ω` is presumably the primal output `abs(x)` -- confirm against callers.
function ChainRulesCore.derivatives_given_output(Ω, ::typeof(abs), x::Union{Real, Complex})
    if x isa Real
        # Real input: the answer is `sign(x)` and `Ω` is never touched.
        return ((sign(x),),)
    end
    # Non-real input: divide `x` by `Ω`, substituting `one(Ω)` as the divisor
    # whenever `x` is zero.
    divisor = ifelse(iszero(x), one(Ω), Ω)
    return ((x / divisor,),)
end
87+
8388
## abs2
8489
function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex})
8590
return abs2(z), 2 * realdot(z, Δz)

src/rulesets/Base/mapreduce.jl

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -72,30 +72,74 @@ function rrule(
7272
end
7373

7474
function rrule(
75-
config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=:
76-
)
77-
fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
78-
y = sum(first, fx_and_pullbacks; dims=dims)
75+
config::RuleConfig{>:HasReverseMode},
76+
::typeof(sum),
77+
f::F,
78+
xs::AbstractArray{T};
79+
dims = :,
80+
) where {F,T}
81+
project = ProjectTo(xs)
7982

80-
pullbacks = last.(fx_and_pullbacks)
83+
if _uses_input_only(f, T)
84+
# Then we can compute the forward pass as usual, save nothing but `xs`:
85+
function sum_pullback_f1(dy)
86+
dxs = broadcast(unthunk(dy), xs) do dyₖ, xᵢ
87+
∂yₖ∂xᵢ = only(only(derivatives_given_output(nothing, f, xᵢ)))
88+
dyₖ * conj(∂yₖ∂xᵢ)
89+
end
90+
return (NoTangent(), NoTangent(), project(dxs))
91+
end
92+
return sum(f, xs; dims), sum_pullback_f1
93+
end
8194

82-
project = ProjectTo(xs)
95+
# (There is an intermediate case, where `derivatives_given_output` needs to
96+
# see `f.(xs)` but we don't need the pullbacks. Not implemented at present.)
97+
98+
# In the general case, we need to save all the pullbacks:
99+
fx_and_pullbacks = map(xᵢ -> rrule_via_ad(config, f, xᵢ), xs)
100+
y = sum(first, fx_and_pullbacks; dims)
101+
102+
function sum_pullback_f2(dy)
103+
# For arrays of arrays, we ought to protect the element against broadcasting:
104+
broadcast_dy = dims isa Colon ? Ref(unthunk(dy)) : unthunk(dy)
105+
if Base.issingletontype(F)
106+
# Then at least `f` has no gradient.
107+
# Broadcasting here gets the shape right with or without `dims` keyword.
108+
dxs = broadcast(fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
109+
unthunk(last(pbᵢ(dyₖ)))
110+
end
111+
return (NoTangent(), NoTangent(), project(dxs))
83112

84-
function sum_pullback(ȳ)
85-
call(f, x) = f(x)
86-
# if dims is :, then need only left-handed only broadcast
87-
broadcast_ȳ = dims isa Colon ? (ȳ,) : ȳ
88-
f̄_and_x̄s = call.(pullbacks, broadcast_ȳ)
89-
# no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
90-
f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
91-
NoTangent()
92113
else
93-
sum(first, f̄_and_x̄s)
114+
# Most general case. If `f` were stateful, we would need to reverse the order
115+
# of iteration here, but since this function makes no guarantee, even the primal
116+
# result is then ill-defined.
117+
df_and_dxs = broadcast(fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
118+
pbᵢ(dyₖ)
119+
end
120+
df = sum(first, df_and_dxs)
121+
dxs = map(unthunk ∘ last, df_and_dxs)
122+
return (NoTangent(), df, project(dxs))
94123
end
95-
x̄s = map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
96-
return NoTangent(), f̄, project(x̄s)
97124
end
98-
return y, sum_pullback
125+
return y, sum_pullback_f2
126+
end
127+
128+
"""
129+
_uses_input_only(f, xT::Type)
130+
131+
Returns `true` if it can prove that `derivatives_given_output` will work using only the input
132+
of the given type. Thus there is no need to store the output `y = f(x::xT)`, allowing us to take
133+
a fast path in the `rrule` for `sum(f, xs)`.
134+
135+
Works by seeing if the result of `derivatives_given_output(nothing, f, x)` can be inferred.
136+
The method of `derivatives_given_output` usually comes from `@scalar_rule`.
137+
"""
138+
function _uses_input_only(f::F, ::Type{xT}) where {F,xT}
    # Ask inference what `derivatives_given_output(nothing, f, x::xT)` would return,
    # i.e. what happens when the output slot is filled with `nothing`.
    argtypes = Tuple{Nothing, F, xT}
    gT = Core.Compiler._return_type(derivatives_given_output, argtypes)
    # An abstract inferred type means we cannot prove anything, so decline.
    isconcretetype(gT) || return false
    # Requiring a `Number` inside the nested tuple rules out a rule that merely hands
    # the output back (and so would return `nothing` here), for instance:
    #     ChainRules.derivatives_given_output("anything", exp, 1) == (("anything",),)
    return gT <: Tuple{Tuple{Number}}
end
100144

101145
# https://github.com/JuliaDiff/ChainRules.jl/issues/522
@@ -228,6 +272,7 @@ function ∇prod_dims(vald::Val{dims}, x, dy, y=prod(x; dims=dims)) where {dims}
228272
∇prod_dims!(dx, vald, x, dy, y)
229273
return dx
230274
end
275+
# Method for an `AbstractZero` cotangent: `dy` is returned unchanged and neither
# `x` nor `y` is read.
function ∇prod_dims(::Val, x, dy::AbstractZero, y=0)
    return dy
end
231276

232277
function ∇prod_dims!(dx, ::Val{dims}, x, dy, y) where {dims}
233278
iters = ntuple(d -> d in dims ? tuple(:) : axes(x,d), ndims(x)) # Without Val(dims) this is a serious type instability
@@ -244,6 +289,7 @@ function ∇prod(x, dy::Number=1, y::Number=prod(x))
244289
∇prod!(dx, x, dy, y)
245290
return dx
246291
end
292+
# Method for an `AbstractZero` cotangent: `dy` is returned unchanged and neither
# `x` nor `y` is read.
function ∇prod(x, dy::AbstractZero, y::Number=0)
    return dy
end
247293

248294
function ∇prod!(dx, x, dy::Number=1, y::Number=prod(x))
249295
numzero = iszero(y) ? count(iszero, x) : 0
@@ -326,7 +372,8 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
326372
dx = fill!(similar(x, T, axes(x)), zero(T))
327373
∇cumprod_dim!(dx, vald, x, dy, y)
328374
return dx
329-
end
375+
end
376+
# Method for an `AbstractZero` cotangent: `dy` is returned unchanged; `vald`, `x`
# and `y` are not read.
function ∇cumprod_dim(vald::Val, x::AbstractArray, dy::AbstractZero, y=0)
    return dy
end
330377

331378
@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim}
332379
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
@@ -342,6 +389,7 @@ function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
342389
∇cumprod!(dx, x, dy, y)
343390
return dx
344391
end
392+
# Method for an `AbstractZero` cotangent: `dy` is returned unchanged and neither
# `x` nor `y` is read.
function ∇cumprod(x::AbstractVector, dy::AbstractZero, y=0)
    return dy
end
345393

346394
@inline function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y)
347395
lo, hi = firstindex(x), lastindex(x)

test/rulesets/Base/mapreduce.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
6565
@testset "sum(f, xs)" begin
6666
# This calls back into AD
6767
test_rrule(sum, abs, [-4.0, 2.0, 2.0])
68+
test_rrule(sum, log, rand(3, 4) .+ 1)
6869
test_rrule(sum, cbrt, randn(5))
6970
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])
7071

7172
# Complex numbers
73+
test_rrule(sum, log, rand(ComplexF64, 5))
7274
test_rrule(sum, sqrt, rand(ComplexF64, 5))
7375
test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real
7476

@@ -82,6 +84,12 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
8284

8385
test_rrule(sum, abs, @SVector[1.0, -3.0])
8486

87+
# Make sure the above test both `derivatives_given_output` path and general case:
88+
@test ChainRules._uses_input_only(abs, Float32)
89+
@test !ChainRules._uses_input_only(cbrt, Float64)
90+
@test ChainRules._uses_input_only(log, ComplexF64)
91+
@test !ChainRules._uses_input_only(abs, ComplexF64)
92+
8593
# covectors
8694
x = [-4.0 2.0; 2.0 -1.0]
8795
test_rrule(sum, inv, x[1, :]')
@@ -102,14 +110,22 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
102110
# ... and Bool produced by function
103111
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")
104112

105-
106113
# Functions that return a Vector
107114
# see https://github.com/FluxML/Zygote.jl/issues/1074
108115
test_rrule(sum, make_two_vec, [1.0, 3.0, 5.0, 7.0])
109116
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0])
110117
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=2))
111118
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=1))
112119
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=(3, 4)))
120+
121+
# arrays of arrays, functions which return a scalar:
122+
test_rrule(sum, sum, [[1,2], [3,4], [5,6]]; check_inferred=false)
123+
x2345 = [rand(2,3) for _ in 1:4, _ in 1:5]
124+
test_rrule(sum, prod, x2345; check_inferred=false)
125+
test_rrule(sum, sum, x2345; fkwargs=(;dims=1), check_inferred=false)
126+
test_rrule(sum, sum, x2345; fkwargs=(;dims=(1,2)), check_inferred=false)
127+
128+
test_rrule(sum, cumprod, [[1,2], [3,4], [5,6]]; check_inferred=false)
113129
end
114130

115131
# https://github.com/JuliaDiff/ChainRules.jl/issues/522

0 commit comments

Comments
 (0)