Skip to content

Commit 7e7d105

Browse files
committed
change generic rule to use BroadcastStyle
1 parent 2583c8c commit 7e7d105

File tree

2 files changed

+117
-95
lines changed

2 files changed

+117
-95
lines changed

src/rulesets/Base/broadcast.jl

Lines changed: 63 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using Base.Broadcast: Broadcast, broadcasted, Broadcasted
1+
using Base.Broadcast: Broadcast, broadcasted, Broadcasted, BroadcastStyle
22
const RCR = RuleConfig{>:HasReverseMode}
3+
const TRI_NO = (NoTangent(), NoTangent(), NoTangent())
34

45
function rrule(::typeof(copy), bc::Broadcasted)
56
uncopy(Δ) = (NoTangent(), Δ)
@@ -22,12 +23,16 @@ _print(args...) = printstyled("CR: ", join(args, " "), "\n", color=:magenta) # n
2223
# and we don't know whether re-computing `y` is cheap.
2324
# (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.)
2425

25-
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
26+
# This rule has `::BroadcastStyle` in part becuase Zygote's generic rule does, to avoid ambiguities.
27+
# It applies one step later in AD, and all args have `broadcastable(x)` thus many have `Ref(x)`, complicating some tests.
28+
# But it also means that the lazy rules below do not need `::RuleConfig{>:HasReverseMode}` just for dispatch.
29+
30+
function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N}
2631
T = Broadcast.combine_eltypes(f, args)
2732
if T === Bool # TODO use nondifftype here
2833
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
2934
_print("split_bc_trivial", f)
30-
bc_trivial_back(_) = (NoTangent(), NoTangent(), ntuple(Returns(ZeroTangent()), length(args))...)
35+
bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)
3136
return f.(args...), bc_trivial_back
3237
elseif T <: Number && may_bc_derivatives(T, f, args...)
3338
# 2: Fast path: use arguments & result to find derivatives.
@@ -59,9 +64,9 @@ function split_bc_derivatives(f::F, arg) where {F}
5964
das = only(derivatives_given_output(y, f, a))
6065
dy * conj(only(das)) # possibly this * should be made nan-safe.
6166
end
62-
return (NoTangent(), NoTangent(), ProjectTo(arg)(delta))
67+
return (TRI_NO..., ProjectTo(arg)(delta))
6368
end
64-
bc_one_back(z::AbstractZero) = (NoTangent(), NoTangent(), z)
69+
bc_one_back(z::AbstractZero) = (TRI_NO..., z)
6570
return ys, bc_one_back
6671
end
6772
function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N}
@@ -73,9 +78,9 @@ function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N}
7378
map(da -> dy * conj(da), das) # possibly this * should be made nan-safe.
7479
end
7580
dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast?
76-
return (NoTangent(), NoTangent(), dargs...)
81+
return (TRI_NO..., dargs...)
7782
end
78-
bc_many_back(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
83+
bc_many_back(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...)
7984
return ys, bc_many_back
8085
end
8186

@@ -108,9 +113,9 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
108113
delta = broadcast(ydots, unthunk(dys), arg) do ydot, dy, a
109114
ProjectTo(a)(conj(ydot) * dy) # possibly this * should be made nan-safe.
110115
end
111-
return (NoTangent(), NoTangent(), ProjectTo(arg)(delta))
116+
return (TRI_NO..., ProjectTo(arg)(delta))
112117
end
113-
back_forwards(z::AbstractZero) = (NoTangent(), NoTangent(), z)
118+
back_forwards(z::AbstractZero) = (TRI_NO..., z)
114119
return ys, back_forwards
115120
end
116121

@@ -129,32 +134,31 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
129134
end
130135
dargs = map(unbroadcast, args, Base.tail(deltas))
131136
df = ProjectTo(f)(sum(first(deltas)))
132-
return (NoTangent(), df, dargs...)
137+
return (NoTangent(), NoTangent(), df, dargs...)
133138
end
134-
back_generic(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...)
139+
back_generic(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...)
135140
return ys3, back_generic
136141
end
137142

138143
# Don't run broadcasting on scalars
139-
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Number...) where {F}
144+
function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Number...) where {F}
140145
_print("split_bc_scalar", f)
141146
z, back = rrule_via_ad(cfg, f, args...)
142-
return z, dz -> (NoTangent(), back(dz)...)
147+
return z, dz -> (NoTangent(), NoTangent(), back(dz)...)
143148
end
144149

145150
#####
146151
##### Fused broadcasting
147152
#####
148153

149154
# For certain cheap operations we can easily allow fused broadcast; the forward pass may be run twice.
150-
# These all have `RuleConfig{>:HasReverseMode}` only for dispatch, to beat the split rule above.
151155
# Accept `x::Broadcasted` because they produce it; can't dispatch on eltype but `x` is assumed to contain `Number`s.
152156

153157
const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted}
154158

155159
##### Arithmetic: +, -, *, ^2, /
156160

157-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
161+
function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
158162
_print("plus", length(xs))
159163
function bc_plus_back(dy_raw)
160164
dy = unthunk(dy_raw)
@@ -163,7 +167,7 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast
163167
return broadcasted(+, xs...), bc_plus_back
164168
end
165169

166-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
170+
function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
167171
_print("minus 2")
168172
function bc_minus_back(dz_raw)
169173
dz = unthunk(dz_raw)
@@ -172,13 +176,13 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast,
172176
return broadcasted(-, x, y), bc_minus_back
173177
end
174178

175-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
179+
function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
176180
_print("minus 1")
177181
bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk -unthunk(dy))
178182
return broadcasted(-, x), bc_minus_back
179183
end
180184

181-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
185+
function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
182186
_print("times")
183187
function bc_times_back(Δraw)
184188
Δ = unthunk(Δraw)
@@ -191,22 +195,20 @@ _back_star(x::Number, y, Δ) = @thunk LinearAlgebra.dot(y, Δ) # ... but this i
191195
_back_star(x::Bool, y, Δ) = NoTangent()
192196
_back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x)
193197

194-
#=
195-
# This works, but not sure it improves any benchmarks.
196-
function rrule(cfg::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...)
198+
# This works, but not sure it improves any benchmarks. Needs corresponding scalar rule to avoid ambiguities.
199+
function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...)
197200
_print("times", 2 + length(zs))
198-
xy, back1 = rrule(cfg, broadcasted, *, x, y)
199-
xyz, back2 = rrule(cfg, broadcasted, *, xy, zs...)
201+
xy, back1 = rrule(broadcasted, *, x, y)
202+
xyz, back2 = rrule(broadcasted, *, xy, zs...)
200203
function bc_times3_back(dxyz)
201204
_, _, dxy, dzs... = back2(dxyz)
202205
_, _, dx, dy = back1(dxy)
203206
return (NoTangent(), NoTangent(), dx, dy, dzs...)
204207
end
205208
xyz, bc_times3_back
206209
end
207-
=#
208210

209-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
211+
function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
210212
_print("square")
211213
function bc_square_back(dy_raw)
212214
dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x))
@@ -215,7 +217,7 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeo
215217
return broadcasted(Base.literal_pow, ^, x, Val(2)), bc_square_back
216218
end
217219

218-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
220+
function rrule(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
219221
_print("divide")
220222
# z = broadcast(/, x, y)
221223
z = broadcasted(/, x, y)
@@ -237,75 +239,76 @@ function _prepend_zero((y, back))
237239
return y, extra_back
238240
end
239241

240-
rrule(::RCR, ::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero
241-
rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero
242-
rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero
243-
rrule(::RCR, ::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero
244-
rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero # ambiguity
245-
rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) =
242+
rrule(::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero
243+
rrule(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero
244+
rrule(::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero
245+
rrule(::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero
246+
rrule(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero # ambiguity
247+
rrule(::typeof(broadcasted), ::typeof(*), x::Number, y::Number, zs::Number...) = rrule(*, x, y, zs...) |> _prepend_zero
248+
rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) =
246249
rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero
247-
rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero
250+
rrule(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero
248251

249252
##### Identity, number types
250253

251-
rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero
252-
rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity
254+
rrule(::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero
255+
rrule(::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity
253256

254-
function rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number}
257+
function rrule(::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number}
255258
_print("bc type", T)
256259
bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
257260
return broadcasted(T, x), bc_type_back
258261
end
259-
rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero
262+
rrule(::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero
260263

261-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast)
264+
function rrule(::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast)
262265
_print("bc float")
263266
bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
264267
return broadcasted(float, x), bc_float_back
265268
end
266-
rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero
269+
rrule(::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero
267270

268271
##### Complex: conj, real, imag
269272

270273
for conj in [:conj, :adjoint] # identical as we know eltype <: Number
271274
@eval begin
272-
function rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast)
275+
function rrule(::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast)
273276
bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx)))
274277
return broadcasted($conj, x), bc_conj_back
275278
end
276-
rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero
277-
rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
279+
rrule(::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero
280+
rrule(::typeof(broadcasted), ::typeof($conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
278281
# This `AbstractArray{<:Real}` rule won't catch `conj.(x.+1)` with lazy `.+` rule.
279282
# Could upgrade to infer eltype of the `Broadcasted`?
280283
end
281284
end
282285

283-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast)
286+
function rrule(::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast)
284287
_print("real")
285288
bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz))))
286289
return broadcasted(real, x), bc_real_back
287290
end
288-
rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero
289-
rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
291+
rrule(::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero
292+
rrule(::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
290293

291-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast)
294+
function rrule(::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast)
292295
_print("imag")
293296
bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz))))
294297
return broadcasted(imag, x), bc_imag_back
295298
end
296-
rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero
297-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real})
299+
rrule(::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero
300+
function rrule(::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real})
298301
_print("imag(real)")
299302
bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent())
300303
return broadcasted(imag, x), bc_imag_back_2
301304
end
302305

303-
function rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast)
306+
function rrule(::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast)
304307
_print("bc complex")
305308
bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
306309
return broadcasted(complex, x), bc_complex_back
307310
end
308-
rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero
311+
rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero
309312

310313
#####
311314
##### Shape fixing
@@ -389,8 +392,16 @@ end
389392
##### For testing
390393
#####
391394

392-
function rrule(cfg::RCR, ::typeof(copybroadcasted), f, args...)
393-
y, back = rrule(cfg, broadcasted, f, args...)
395+
function rrule(cfg::RCR, ::typeof(copybroadcasted), f_args...)
396+
tmp = rrule(cfg, broadcasted, f_args...)
397+
isnothing(tmp) && throw("rrule gave nothing")
398+
y, back = tmp
399+
return _maybe_copy(y), back
400+
end
401+
function rrule(::typeof(copybroadcasted), f_args...)
402+
tmp = rrule(broadcasted, f_args...)
403+
isnothing(tmp) && throw("rrule gave nothing")
404+
y, back = tmp
394405
return _maybe_copy(y), back
395406
end
396407

0 commit comments

Comments
 (0)