Skip to content

Commit 509a353

Browse files
committed
re-organise split bc, add forward mode
1 parent ec41724 commit 509a353

File tree

3 files changed

+143
-51
lines changed

3 files changed

+143
-51
lines changed

src/rulesets/Base/broadcast.jl

Lines changed: 117 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,56 +18,121 @@ _print(args...) = nothing # println(join(args, " ")) #
1818
##### Split broadcasting
1919
#####
2020

21+
# For `z = g.(f.(xs))`, this finds `y = f.(x)` eagerly because the rules for either `f` or `g` may need it,
22+
# and we don't know whether re-computing `y` is cheap.
23+
# (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.)
24+
2125
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
22-
# = split_bc_rule(cfg, f, args...)
23-
# function split_bc_rule(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
2426
T = Broadcast.combine_eltypes(f, args)
25-
TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
26-
if T === Bool
27+
if T === Bool # TODO use nondifftype here
2728
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
28-
_print("split_bc_rule 1 ", f)
29-
back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
30-
return f.(args...), back_1
31-
elseif T <: Number && isconcretetype(TΔ)
32-
# 2: Fast path: just broadcast, and use arguments & result to find derivatives.
33-
_print("split_bc_rule 2", f, N)
34-
ys = f.(args...)
35-
function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all
36-
delta = broadcast(unthunk(dys), ys, args...) do dy, y, a
37-
das = only(derivatives_given_output(y, f, a))
38-
dy * conj(only(das)) # possibly this * should be made nan-safe.
39-
end
40-
(NoTangent(), NoTangent(), ProjectTo(only(args))(delta))
29+
_print("split_bc_trivial", f)
30+
bc_trivial_back(_) = (NoTangent(), NoTangent(), ntuple(Returns(ZeroTangent()), length(args))...)
31+
return f.(args...), bc_trivial_back
32+
elseif T <: Number && may_bc_derivatives(T, f, args...)
33+
# 2: Fast path: use arguments & result to find derivatives.
34+
return split_bc_derivatives(f, args...)
35+
elseif T <: Number && may_bc_forwards(cfg, f, args...)
36+
# 3: Future path: use `frule_via_ad`?
37+
return split_bc_forwards(cfg, f, args...)
38+
else
39+
# 4: Slow path: collect all the pullbacks & apply them later.
40+
return split_bc_pullbacks(cfg, f, args...)
41+
end
42+
end
43+
44+
# Path 2: This is roughly what `derivatives_given_output` is designed for, should be fast.
45+
46+
function may_bc_derivatives(::Type{T}, f::F, args::Vararg{Any,N}) where {T,F,N}
47+
TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...})
48+
return isconcretetype(TΔ)
49+
end
50+
51+
_eltype(x) = eltype(x) # ... but try harder to avoid `eltype(Broadcast.broadcasted(+, [1,2,3], 4.5)) == Any`:
52+
_eltype(bc::Broadcast.Broadcasted) = Broadcast.combine_eltypes(bc.f, bc.args)
53+
54+
function split_bc_derivatives(f::F, arg) where {F}
55+
_print("split_bc_derivative", f)
56+
ys = f.(arg)
57+
function bc_one_back(dys) # For f.(x) we do not need StructArrays / unzip at all
58+
delta = broadcast(unthunk(dys), ys, arg) do dy, y, a
59+
das = only(derivatives_given_output(y, f, a))
60+
dy * conj(only(das)) # possibly this * should be made nan-safe.
4161
end
42-
back_2_one(z::AbstractZero) = (NoTangent(), NoTangent(), z)
43-
function back_2_many(dys)
44-
deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as...
45-
das = only(derivatives_given_output(y, f, as...))
46-
map(da -> dy * conj(da), das)
47-
end
48-
dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast?
49-
(NoTangent(), NoTangent(), dargs...)
62+
return (NoTangent(), NoTangent(), ProjectTo(arg)(delta))
63+
end
64+
bc_one_back(z::AbstractZero) = (NoTangent(), NoTangent(), z)
65+
return ys, bc_one_back
66+
end
67+
function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N}
68+
_print("split_bc_derivatives", f, N)
69+
ys = f.(args...)
70+
function bc_many_back(dys)
71+
deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as...
72+
das = only(derivatives_given_output(y, f, as...))
73+
map(da -> dy * conj(da), das) # possibly this * should be made nan-safe.
5074
end
51-
back_2_many(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
52-
return ys, N==1 ? back_2_one : back_2_many
53-
else
54-
_print("split_bc_rule 3", f, N)
55-
# 3: Slow path: collect all the pullbacks & apply them later.
56-
# (Since broadcast makes no guarantee about order of calls, and un-fusing
57-
# can change the number of calls, don't bother to try to reverse the iteration.)
58-
ys3, backs = tuplecast(args...) do a...
59-
rrule_via_ad(cfg, f, a...)
75+
dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast?
76+
return (NoTangent(), NoTangent(), dargs...)
77+
end
78+
bc_many_back(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
79+
return ys, bc_many_back
80+
end
81+
82+
# Path 3: Use forward mode, or an `frule` if one exists.
83+
# To allow `args...` we need either chunked forward mode, with `adot::Tuple` perhaps:
84+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92
85+
# https://github.com/JuliaDiff/Diffractor.jl/pull/54
86+
# Or else we need to call the `f` multiple times, and maybe that's OK:
87+
# We do know that `f` doesn't have parameters, so maybe it's pure enough,
88+
# and split broadcasting may anyway change N^2 executions into N, e.g. `g.(v ./ f.(v'))`.
89+
# We don't know `f` is cheap, but `split_bc_pullbacks` tends to be very slow.
90+
91+
function may_bc_forwards(cfg::C, f::F, args::Vararg{Any,N}) where {C,F,N}
92+
Base.issingletontype(F) || return false
93+
N==1 || return false # Could weaken this to 1 differentiable
94+
cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad
95+
TA = map(_eltype, args)
96+
TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA...}, F, TA...})
97+
return isconcretetype(TF) && TF <: Tuple
98+
end
99+
100+
split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = split_bc_inner(frule_via_ad, cfg, f, arg)
101+
split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg)
102+
function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
103+
_print("split_bc_forwards", frule_fun, f)
104+
ys, ydots = tuplecast(arg) do a
105+
frule_fun(cfg, (NoTangent(), one(a)), f, a)
106+
end
107+
function back_forwards(dys)
108+
delta = broadcast(ydots, unthunk(dys), arg) do ydot, dy, a
109+
ProjectTo(a)(conj(ydot) * dy) # possibly this * should be made nan-safe.
60110
end
61-
function back_3(dys)
62-
deltas = tuplecast(backs, unthunk(dys)) do back, dy # could be map, sizes match
63-
map(unthunk, back(dy))
64-
end
65-
dargs = map(unbroadcast, args, Base.tail(deltas))
66-
(NoTangent(), ProjectTo(f)(sum(first(deltas))), dargs...)
111+
return (NoTangent(), NoTangent(), ProjectTo(arg)(delta))
112+
end
113+
back_forwards(z::AbstractZero) = (NoTangent(), NoTangent(), z)
114+
return ys, back_forwards
115+
end
116+
117+
# Path 4: The most generic, save all the pullbacks. Can be 1000x slower.
118+
# Since broadcast makes no guarantee about order of calls, and un-fusing
119+
# can change the number of calls, don't bother to try to reverse the iteration.
120+
121+
function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
122+
_print("split_bc_generic", f, N)
123+
ys3, backs = tuplecast(args...) do a...
124+
rrule_via_ad(cfg, f, a...)
125+
end
126+
function back_generic(dys)
127+
deltas = tuplecast(backs, unthunk(dys)) do back, dy # (could be map, sizes match)
128+
map(unthunk, back(dy))
67129
end
68-
back_3(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
69-
return ys3, back_3
130+
dargs = map(unbroadcast, args, Base.tail(deltas))
131+
df = ProjectTo(f)(sum(first(deltas)))
132+
return (NoTangent(), df, dargs...)
70133
end
134+
back_generic(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
135+
return ys3, back_generic
71136
end
72137

73138
# Don't run broadcasting on scalars
@@ -158,8 +223,8 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast,
158223
dz = unthunk(dz_raw)
159224
dx = @thunk unbroadcast(x, dz ./ conj.(y))
160225
# dy = @thunk -LinearAlgebra.dot(z, dz) / conj(y) # the reason to be eager is to allow dot here
161-
dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast?
162-
(NoTangent(), NoTangent(), dx, dy)
226+
dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast
227+
return (NoTangent(), NoTangent(), dx, dy)
163228
end
164229
return z, bc_divide_back
165230
end
@@ -234,6 +299,13 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:
234299
return broadcasted(imag, x), bc_imag_back_2
235300
end
236301

302+
function rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast)
303+
_print("bc complex")
304+
bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
305+
return broadcasted(complex, x), bc_complex_back
306+
end
307+
rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero
308+
237309
#####
238310
##### Shape fixing
239311
#####
@@ -259,7 +331,7 @@ function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
259331
else
260332
sum(dx; dims=2:ndims(dx))
261333
end
262-
ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
334+
return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
263335
end
264336

265337
unbroadcast(f::Function, df) = sum(df)

test/rulesets/Base/broadcast.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Base.Broadcast: broadcasted
22

33
@testset "Broadcasting" begin
4-
@testset "generic 1: trivial path" begin
4+
@testset "split 1: trivial path" begin
55
# test_rrule(copy∘broadcasted, >, rand(3), rand(3)) # MethodError: no method matching eps(::UInt64) inside FiniteDifferences
66
y1, bk1 = rrule(CFG, copy∘broadcasted, >, rand(3), rand(3))
77
@test y1 isa AbstractArray{Bool}
@@ -12,7 +12,7 @@ using Base.Broadcast: broadcasted
1212
@test all(d -> d isa AbstractZero, bk2(99))
1313
end
1414

15-
@testset "generic 2: fast path" begin
15+
@testset "split 2: derivatives" begin
1616
test_rrule(copy∘broadcasted, log, rand(3))
1717
test_rrule(copy∘broadcasted, log, Tuple(rand(3)))
1818

@@ -23,16 +23,22 @@ using Base.Broadcast: broadcasted
2323
test_rrule(copy∘broadcasted, atan, rand(3), Tuple(rand(1)))
2424
test_rrule(copy∘broadcasted, atan, Tuple(rand(3)), Tuple(rand(3)))
2525

26-
# Protected by Ref/Tuple:
2726
test_rrule(copy∘broadcasted, *, rand(3), Ref(rand()))
28-
test_rrule(copy∘broadcasted, *, rand(3), Ref(rand(2)))
27+
end
28+
29+
@testset "split 3: forwards" begin
30+
test_rrule(copy∘broadcasted, flog, rand(3))
31+
test_rrule(copy∘broadcasted, flog, rand(3) .+ im)
32+
# Also, `sin∘cos` may use this path as CFG uses frule_via_ad
2933
end
3034

31-
@testset "generic 3: slow path" begin
35+
@testset "split 4: generic" begin
3236
test_rrule(copy∘broadcasted, sin∘cos, rand(3), check_inferred=false)
3337
test_rrule(copy∘broadcasted, sin∘atan, rand(3), rand(3)', check_inferred=false)
3438
test_rrule(copy∘broadcasted, sin∘atan, rand(), rand(3), check_inferred=false)
35-
test_rrule(copy∘broadcasted, ^, rand(3), 3.0, check_inferred=false)
39+
test_rrule(copy∘broadcasted, ^, rand(3), 3.0, check_inferred=false) # NoTangent vs. Union{NoTangent, ZeroTangent}
40+
# Many have quite small inference failures, like:
41+
# return type Tuple{NoTangent, NoTangent, Vector{Float64}, Float64} does not match inferred return type Tuple{NoTangent, Union{NoTangent, ZeroTangent}, Vector{Float64}, Float64}
3642

3743
# From test_helpers.jl
3844
test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), check_inferred=false)

test/test_helpers.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x)
7575
return make_two_vec(x), make_two_vec_pullback
7676
end
7777

78+
"A version of `*` with only an `frule` defined"
79+
fstar(A, B) = A * B
80+
ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(fstar), A, B) = A * B, muladd(ΔA, B, A * ΔB)
81+
82+
"A version of `log` with only an `frule` defined"
83+
flog(x::Number) = log(x)  # scalar-only wrapper around `log`; restricted to `Number` inputs
84+
ChainRulesCore.frule((_, xdot), ::typeof(flog), x::Number) = log(x), inv(x) * xdot
85+
7886
@testset "test_helpers.jl" begin
7987

8088
@testset "Multiplier" begin
@@ -103,5 +111,11 @@ end
103111
@testset "make_two_vec" begin
104112
test_rrule(make_two_vec, 1.5)
105113
end
114+
115+
@testset "fstar, flog" begin
116+
test_frule(fstar, 1.2, 3.4 + 5im)
117+
test_frule(flog, 6.7)
118+
test_frule(flog, 8.9 + im)
119+
end
106120

107121
end

0 commit comments

Comments
 (0)