1
- using Base. Broadcast: Broadcast, broadcasted, Broadcasted
1
+ using Base. Broadcast: Broadcast, broadcasted, Broadcasted, BroadcastStyle
2
2
const RCR = RuleConfig{>: HasReverseMode }
3
+ const TRI_NO = (NoTangent (), NoTangent (), NoTangent ())
3
4
4
5
function rrule (:: typeof (copy), bc:: Broadcasted )
5
6
uncopy (Δ) = (NoTangent (), Δ)
@@ -22,12 +23,16 @@ _print(args...) = printstyled("CR: ", join(args, " "), "\n", color=:magenta) # n
22
23
# and we don't know whether re-computing `y` is cheap.
23
24
# (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.)
24
25
25
- function rrule (cfg:: RCR , :: typeof (broadcasted), f:: F , args:: Vararg{Any,N} ) where {F,N}
26
+ # This rule has `::BroadcastStyle` in part because Zygote's generic rule does, to avoid ambiguities.
27
+ # It applies one step later in AD, and all args have `broadcastable(x)` thus many have `Ref(x)`, complicating some tests.
28
+ # But it also means that the lazy rules below do not need `::RuleConfig{>:HasReverseMode}` just for dispatch.
29
+
30
+ function rrule (cfg:: RCR , :: typeof (broadcasted), :: BroadcastStyle , f:: F , args:: Vararg{Any,N} ) where {F,N}
26
31
T = Broadcast. combine_eltypes (f, args)
27
32
if T === Bool # TODO use nondifftype here
28
33
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
29
34
_print (" split_bc_trivial" , f)
30
- bc_trivial_back (_) = (NoTangent (), NoTangent () , ntuple (Returns (ZeroTangent ()), length (args))... )
35
+ bc_trivial_back (_) = (TRI_NO ... , ntuple (Returns (ZeroTangent ()), length (args))... )
31
36
return f .(args... ), bc_trivial_back
32
37
elseif T <: Number && may_bc_derivatives (T, f, args... )
33
38
# 2: Fast path: use arguments & result to find derivatives.
@@ -59,9 +64,9 @@ function split_bc_derivatives(f::F, arg) where {F}
59
64
das = only (derivatives_given_output (y, f, a))
60
65
dy * conj (only (das)) # possibly this * should be made nan-safe.
61
66
end
62
- return (NoTangent (), NoTangent () , ProjectTo (arg)(delta))
67
+ return (TRI_NO ... , ProjectTo (arg)(delta))
63
68
end
64
- bc_one_back (z:: AbstractZero ) = (NoTangent (), NoTangent () , z)
69
+ bc_one_back (z:: AbstractZero ) = (TRI_NO ... , z)
65
70
return ys, bc_one_back
66
71
end
67
72
function split_bc_derivatives (f:: F , args:: Vararg{Any,N} ) where {F,N}
@@ -73,9 +78,9 @@ function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N}
73
78
map (da -> dy * conj (da), das) # possibly this * should be made nan-safe.
74
79
end
75
80
dargs = map (unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast?
76
- return (NoTangent (), NoTangent () , dargs... )
81
+ return (TRI_NO ... , dargs... )
77
82
end
78
- bc_many_back (z:: AbstractZero ) = (NoTangent (), NoTangent () , map (Returns (z), args)... )
83
+ bc_many_back (z:: AbstractZero ) = (TRI_NO ... , map (Returns (z), args)... )
79
84
return ys, bc_many_back
80
85
end
81
86
@@ -108,9 +113,9 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
108
113
delta = broadcast (ydots, unthunk (dys), arg) do ydot, dy, a
109
114
ProjectTo (a)(conj (ydot) * dy) # possibly this * should be made nan-safe.
110
115
end
111
- return (NoTangent (), NoTangent () , ProjectTo (arg)(delta))
116
+ return (TRI_NO ... , ProjectTo (arg)(delta))
112
117
end
113
- back_forwards (z:: AbstractZero ) = (NoTangent (), NoTangent () , z)
118
+ back_forwards (z:: AbstractZero ) = (TRI_NO ... , z)
114
119
return ys, back_forwards
115
120
end
116
121
@@ -129,32 +134,31 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
129
134
end
130
135
dargs = map (unbroadcast, args, Base. tail (deltas))
131
136
df = ProjectTo (f)(sum (first (deltas)))
132
- return (NoTangent (), df, dargs... )
137
+ return (NoTangent (), NoTangent (), df, dargs... )
133
138
end
134
- back_generic (z:: AbstractZero ) = (NoTangent (), NoTangent () , map (Returns (z), args)... )
139
+ back_generic (z:: AbstractZero ) = (TRI_NO ... , map (Returns (z), args)... )
135
140
return ys3, back_generic
136
141
end
137
142
138
143
# Don't run broadcasting on scalars
# When every element of `args` is a `Number`, `f` is differentiated directly with
# `rrule_via_ad(cfg, f, args...)`; no broadcast machinery is involved.
# The returned pullback splats `back(dz)` after leading `NoTangent()` entries: one in the
# removed version, two in the added version whose signature also takes a `::BroadcastStyle`.
139
- function rrule (cfg:: RCR , :: typeof (broadcasted), f:: F , args:: Number... ) where {F}
144
+ function rrule (cfg:: RCR , :: typeof (broadcasted), :: BroadcastStyle , f:: F , args:: Number... ) where {F}
140
145
_print (" split_bc_scalar" , f)
141
146
z, back = rrule_via_ad (cfg, f, args... )
142
- return z, dz -> (NoTangent (), back (dz)... )
147
+ return z, dz -> (NoTangent (), NoTangent (), back (dz)... )
143
148
end
144
149
145
150
# ####
146
151
# #### Fused broadcasting
147
152
# ####
148
153
149
154
# For certain cheap operations we can easily allow fused broadcast; the forward pass may be run twice.
150
- # These all have `RuleConfig{>:HasReverseMode}` only for dispatch, to beat the split rule above.
151
155
# Accept `x::Broadcasted` because they produce it; can't dispatch on eltype but `x` is assumed to contain `Number`s.
152
156
153
157
const NumericOrBroadcast = Union{Number, AbstractArray{<: Number }, NTuple{<: Any ,Number}, Broadcast. Broadcasted}
154
158
155
159
# #### Arithmetic: +, -, *, ^2, /
156
160
157
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (+ ), xs:: NumericOrBroadcast... )
161
+ function rrule (:: typeof (broadcasted), :: typeof (+ ), xs:: NumericOrBroadcast... )
158
162
_print (" plus" , length (xs))
159
163
function bc_plus_back (dy_raw)
160
164
dy = unthunk (dy_raw)
@@ -163,7 +167,7 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast
163
167
return broadcasted (+ , xs... ), bc_plus_back
164
168
end
165
169
166
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (- ), x:: NumericOrBroadcast , y:: NumericOrBroadcast )
170
+ function rrule (:: typeof (broadcasted), :: typeof (- ), x:: NumericOrBroadcast , y:: NumericOrBroadcast )
167
171
_print (" minus 2" )
168
172
function bc_minus_back (dz_raw)
169
173
dz = unthunk (dz_raw)
@@ -172,13 +176,13 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast,
172
176
return broadcasted (- , x, y), bc_minus_back
173
177
end
174
178
175
# Lazy rule for unary broadcast minus, `.-x`.
# Forward pass returns the un-materialised `broadcasted(-, x)`; the pullback returns
# `NoTangent()` twice followed by a thunk that negates the unthunked cotangent.
# The added version drops the leading `::RCR` argument from the signature.
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (- ), x:: NumericOrBroadcast )
179
+ function rrule (:: typeof (broadcasted), :: typeof (- ), x:: NumericOrBroadcast )
176
180
_print (" minus 1" )
177
181
bc_minus_back (dy) = (NoTangent (), NoTangent (), @thunk - unthunk (dy))
178
182
return broadcasted (- , x), bc_minus_back
179
183
end
180
184
181
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (* ), x:: NumericOrBroadcast , y:: NumericOrBroadcast )
185
+ function rrule (:: typeof (broadcasted), :: typeof (* ), x:: NumericOrBroadcast , y:: NumericOrBroadcast )
182
186
_print (" times" )
183
187
function bc_times_back (Δraw)
184
188
Δ = unthunk (Δraw)
@@ -191,22 +195,20 @@ _back_star(x::Number, y, Δ) = @thunk LinearAlgebra.dot(y, Δ) # ... but this i
191
195
_back_star (x:: Bool , y, Δ) = NoTangent ()
192
196
_back_star (x:: Complex{Bool} , y, Δ) = NoTangent () # e.g. for fun.(im.*x)
193
197
194
- #=
195
- # This works, but not sure it improves any benchmarks.
196
- function rrule(cfg::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...)
198
# Lazy rule for broadcast `*` with three or more arguments, built from two nested calls:
# first the two-argument rule on `(x, y)`, then the rule on `(xy, zs...)`.
# The pullback runs `back2` and then `back1`, discarding each call's first two entries,
# and returns the tangents in argument order `(dx, dy, dzs...)` after two `NoTangent()`s.
+ # This works, but not sure it improves any benchmarks. Needs corresponding scalar rule to avoid ambiguities.
199
+ function rrule (:: typeof (broadcasted), :: typeof (* ), x:: NumericOrBroadcast , y:: NumericOrBroadcast , zs:: NumericOrBroadcast... )
197
200
_print (" times" , 2 + length (zs))
198
- xy, back1 = rrule(cfg, broadcasted, *, x, y)
199
- xyz, back2 = rrule(cfg, broadcasted, *, xy, zs...)
201
+ xy, back1 = rrule (broadcasted, * , x, y)
202
+ xyz, back2 = rrule (broadcasted, * , xy, zs... )
200
203
function bc_times3_back (dxyz)
201
204
_, _, dxy, dzs... = back2 (dxyz)
202
205
_, _, dx, dy = back1 (dxy)
203
206
return (NoTangent (), NoTangent (), dx, dy, dzs... )
204
207
end
205
208
xyz, bc_times3_back
206
209
end
207
- =#
208
210
209
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: NumericOrBroadcast , :: Val{2} )
211
+ function rrule (:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: NumericOrBroadcast , :: Val{2} )
210
212
_print (" square" )
211
213
function bc_square_back (dy_raw)
212
214
dx = @thunk ProjectTo (x)(2 .* unthunk (dy_raw) .* conj .(x))
@@ -215,7 +217,7 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeo
215
217
return broadcasted (Base. literal_pow, ^ , x, Val (2 )), bc_square_back
216
218
end
217
219
218
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (/ ), x:: NumericOrBroadcast , y:: Number )
220
+ function rrule (:: typeof (broadcasted), :: typeof (/ ), x:: NumericOrBroadcast , y:: Number )
219
221
_print (" divide" )
220
222
# z = broadcast(/, x, y)
221
223
z = broadcasted (/ , x, y)
@@ -237,75 +239,76 @@ function _prepend_zero((y, back))
237
239
return y, extra_back
238
240
end
239
241
240
# Scalar shortcuts for arithmetic: when all arguments are `Number`s, each method forwards to
# the ordinary (non-broadcast) `rrule` for the same function and pipes the result through
# `_prepend_zero`.
# NOTE(review): `_prepend_zero` is defined above this hunk; presumably it prepends one extra
# non-differentiable tangent for `broadcasted` -- confirm against its definition.
# The methods marked `# ambiguity`, and the added `x, y, zs...` method for `*`, appear to exist
# to resolve dispatch ambiguities with the lazy rules.
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (+ ), args:: Number... ) = rrule (+ , args... ) |> _prepend_zero
241
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (- ), x:: Number , y:: Number ) = rrule (- , x, y) |> _prepend_zero
242
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (- ), x:: Number ) = rrule (- , x) |> _prepend_zero
243
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (* ), args:: Number... ) = rrule (* , args... ) |> _prepend_zero
244
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (* ), x:: Number , y:: Number ) = rrule (* , x, y) |> _prepend_zero # ambiguity
245
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Number , :: Val{2} ) =
242
+ rrule (:: typeof (broadcasted), :: typeof (+ ), args:: Number... ) = rrule (+ , args... ) |> _prepend_zero
243
+ rrule (:: typeof (broadcasted), :: typeof (- ), x:: Number , y:: Number ) = rrule (- , x, y) |> _prepend_zero
244
+ rrule (:: typeof (broadcasted), :: typeof (- ), x:: Number ) = rrule (- , x) |> _prepend_zero
245
+ rrule (:: typeof (broadcasted), :: typeof (* ), args:: Number... ) = rrule (* , args... ) |> _prepend_zero
246
+ rrule (:: typeof (broadcasted), :: typeof (* ), x:: Number , y:: Number ) = rrule (* , x, y) |> _prepend_zero # ambiguity
247
+ rrule (:: typeof (broadcasted), :: typeof (* ), x:: Number , y:: Number , zs:: Number... ) = rrule (* , x, y, zs... ) |> _prepend_zero
248
+ rrule (:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Number , :: Val{2} ) =
246
249
rrule (Base. literal_pow, ^ , x, Val (2 )) |> _prepend_zero
247
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (/ ), x:: Number , y:: Number ) = rrule (/ , x, y) |> _prepend_zero
250
+ rrule (:: typeof (broadcasted), :: typeof (/ ), x:: Number , y:: Number ) = rrule (/ , x, y) |> _prepend_zero
248
251
249
252
# #### Identity, number types
250
253
251
# `identity.(x)`: both methods forward to `rrule(identity, x)` and pipe the result through
# `_prepend_zero`. The `x::Number` method repeats the same body; its `# ambiguity` tag
# indicates it exists only to disambiguate dispatch.
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (identity), x:: NumericOrBroadcast ) = rrule (identity, x) |> _prepend_zero
252
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (identity), x:: Number ) = rrule (identity, x) |> _prepend_zero # ambiguity
254
+ rrule (:: typeof (broadcasted), :: typeof (identity), x:: NumericOrBroadcast ) = rrule (identity, x) |> _prepend_zero
255
+ rrule (:: typeof (broadcasted), :: typeof (identity), x:: Number ) = rrule (identity, x) |> _prepend_zero # ambiguity
253
256
254
# Broadcasting a number type `T <: Number` used as a constructor, `T.(x)`.
# Forward pass returns the lazy `broadcasted(T, x)`. The pullback returns two `NoTangent()`s
# and a thunk applying `unbroadcast(x, ·)` to the unthunked cotangent.
# The trailing one-line method handles scalar `x` by forwarding to `rrule(T, x)`.
- function rrule (:: RCR , :: typeof (broadcasted), :: Type{T} , x:: NumericOrBroadcast ) where {T<: Number }
257
+ function rrule (:: typeof (broadcasted), :: Type{T} , x:: NumericOrBroadcast ) where {T<: Number }
255
258
_print (" bc type" , T)
256
259
bc_type_back (dz) = (NoTangent (), NoTangent (), @thunk (unbroadcast (x, unthunk (dz))))
257
260
return broadcasted (T, x), bc_type_back
258
261
end
259
- rrule (:: RCR , :: typeof (broadcasted), :: Type{T} , x:: Number ) where {T<: Number } = rrule (T, x) |> _prepend_zero
262
+ rrule (:: typeof (broadcasted), :: Type{T} , x:: Number ) where {T<: Number } = rrule (T, x) |> _prepend_zero
260
263
261
# `float.(x)`: forward pass returns the lazy `broadcasted(float, x)`. The pullback returns
# two `NoTangent()`s and a thunk applying `unbroadcast(x, ·)` to the unthunked cotangent.
# The trailing one-line method handles scalar `x` by forwarding to `rrule(float, x)`.
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (float), x:: NumericOrBroadcast )
264
+ function rrule (:: typeof (broadcasted), :: typeof (float), x:: NumericOrBroadcast )
262
265
_print (" bc float" )
263
266
bc_float_back (dz) = (NoTangent (), NoTangent (), @thunk (unbroadcast (x, unthunk (dz))))
264
267
return broadcasted (float, x), bc_float_back
265
268
end
266
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (float), x:: Number ) = rrule (float, x) |> _prepend_zero
269
+ rrule (:: typeof (broadcasted), :: typeof (float), x:: Number ) = rrule (float, x) |> _prepend_zero
267
270
268
271
# #### Complex: conj, real, imag
269
272
270
273
# Generates the same three rules for both `conj` and `adjoint` via `@eval`.
# Only `$ conj` is interpolated with the loop symbol; the un-interpolated `conj(...)` inside
# the pullback is the ordinary `conj` function in both generated versions.
# - NumericOrBroadcast: lazy forward `broadcasted($conj, x)`, pullback conjugates the cotangent.
# - Number: forwards to the scalar `rrule($conj, x)`.
# - AbstractArray{<:Real}: forwards to the `identity` rule.
for conj in [:conj , :adjoint ] # identical as we know eltype <: Number
271
274
@eval begin
272
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof ($ conj), x:: NumericOrBroadcast )
275
+ function rrule (:: typeof (broadcasted), :: typeof ($ conj), x:: NumericOrBroadcast )
273
276
bc_conj_back (dx) = (NoTangent (), NoTangent (), conj (unthunk (dx)))
274
277
return broadcasted ($ conj, x), bc_conj_back
275
278
end
276
- rrule (:: RCR , :: typeof (broadcasted), :: typeof ($ conj), x:: Number ) = rrule ($ conj, x) |> _prepend_zero
277
- rrule (:: RCR , :: typeof (broadcasted), :: typeof ($ conj), x:: AbstractArray{<:Real} ) = rrule (identity, x) |> _prepend_zero
279
+ rrule (:: typeof (broadcasted), :: typeof ($ conj), x:: Number ) = rrule ($ conj, x) |> _prepend_zero
280
+ rrule (:: typeof (broadcasted), :: typeof ($ conj), x:: AbstractArray{<:Real} ) = rrule (identity, x) |> _prepend_zero
278
281
# This `AbstractArray{<:Real}` rule won't catch `conj.(x.+1)` with lazy `.+` rule.
279
282
# Could upgrade to infer eltype of the `Broadcasted`?
280
283
end
281
284
end
282
285
283
# `real.(x)`: forward pass returns the lazy `broadcasted(real, x)`. The pullback returns two
# `NoTangent()`s and a thunk taking `real` of the unthunked cotangent.
# Scalar `x` forwards to `rrule(real, x)`; `AbstractArray{<:Real}` forwards to the
# `identity` rule.
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (real), x:: NumericOrBroadcast )
286
+ function rrule (:: typeof (broadcasted), :: typeof (real), x:: NumericOrBroadcast )
284
287
_print (" real" )
285
288
bc_real_back (dz) = (NoTangent (), NoTangent (), @thunk (real (unthunk (dz))))
286
289
return broadcasted (real, x), bc_real_back
287
290
end
288
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (real), x:: Number ) = rrule (real, x) |> _prepend_zero
289
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (real), x:: AbstractArray{<:Real} ) = rrule (identity, x) |> _prepend_zero
291
+ rrule (:: typeof (broadcasted), :: typeof (real), x:: Number ) = rrule (real, x) |> _prepend_zero
292
+ rrule (:: typeof (broadcasted), :: typeof (real), x:: AbstractArray{<:Real} ) = rrule (identity, x) |> _prepend_zero
290
293
291
# `imag.(x)`: forward pass returns the lazy `broadcasted(imag, x)`. The pullback returns two
# `NoTangent()`s and a thunk computing `im .* real.(dz)` from the unthunked cotangent.
# Scalar `x` forwards to `rrule(imag, x)`. For `AbstractArray{<:Real}` the pullback ignores
# its argument and returns `ZeroTangent()` for `x`.
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (imag), x:: NumericOrBroadcast )
294
+ function rrule (:: typeof (broadcasted), :: typeof (imag), x:: NumericOrBroadcast )
292
295
_print (" imag" )
293
296
bc_imag_back (dz) = (NoTangent (), NoTangent (), @thunk (im .* real .(unthunk (dz))))
294
297
return broadcasted (imag, x), bc_imag_back
295
298
end
296
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (imag), x:: Number ) = rrule (imag, x) |> _prepend_zero
297
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (imag), x:: AbstractArray{<:Real} )
299
+ rrule (:: typeof (broadcasted), :: typeof (imag), x:: Number ) = rrule (imag, x) |> _prepend_zero
300
+ function rrule (:: typeof (broadcasted), :: typeof (imag), x:: AbstractArray{<:Real} )
298
301
_print (" imag(real)" )
299
302
bc_imag_back_2 (dz) = (NoTangent (), NoTangent (), ZeroTangent ())
300
303
return broadcasted (imag, x), bc_imag_back_2
301
304
end
302
305
303
# `complex.(x)`: forward pass returns the lazy `broadcasted(complex, x)`. The pullback returns
# two `NoTangent()`s and a thunk applying `unbroadcast(x, ·)` to the unthunked cotangent.
# The trailing one-line method handles scalar `x` by forwarding to `rrule(complex, x)`.
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (complex), x:: NumericOrBroadcast )
306
+ function rrule (:: typeof (broadcasted), :: typeof (complex), x:: NumericOrBroadcast )
304
307
_print (" bc complex" )
305
308
bc_complex_back (dz) = (NoTangent (), NoTangent (), @thunk (unbroadcast (x, unthunk (dz))))
306
309
return broadcasted (complex, x), bc_complex_back
307
310
end
308
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (complex), x:: Number ) = rrule (complex, x) |> _prepend_zero
311
+ rrule (:: typeof (broadcasted), :: typeof (complex), x:: Number ) = rrule (complex, x) |> _prepend_zero
309
312
310
313
# ####
311
314
# #### Shape fixing
389
392
# #### For testing
390
393
# ####
391
394
392
# Test helpers: rules for the composed function `copy∘broadcasted`.
# Each method calls the matching `rrule` for `broadcasted` (with or without `cfg`), fails
# loudly if that call returned `nothing`, and returns `_maybe_copy(y)` with the pullback
# passed through unchanged.
# The added version takes a single `f_args...` splat instead of `f, args...`, and adds a
# second method without a `RuleConfig` argument.
# NOTE(review): `throw` is given a bare `String` here rather than an `Exception` subtype.
- function rrule (cfg:: RCR , :: typeof (copy∘ broadcasted), f, args... )
393
- y, back = rrule (cfg, broadcasted, f, args... )
395
+ function rrule (cfg:: RCR , :: typeof (copy∘ broadcasted), f_args... )
396
+ tmp = rrule (cfg, broadcasted, f_args... )
397
+ isnothing (tmp) && throw (" rrule gave nothing" )
398
+ y, back = tmp
399
+ return _maybe_copy (y), back
400
+ end
401
+ function rrule (:: typeof (copy∘ broadcasted), f_args... )
402
+ tmp = rrule (broadcasted, f_args... )
403
+ isnothing (tmp) && throw (" rrule gave nothing" )
404
+ y, back = tmp
394
405
return _maybe_copy (y), back
395
406
end
396
407
0 commit comments