1
1
using Base. Broadcast: Broadcast, broadcasted, Broadcasted
2
2
# Shorthand for a rule configuration whose capabilities include `HasReverseMode`.
# The `rrule` methods for `broadcasted` in this file take it as their first argument.
const RCR = RuleConfig{>: HasReverseMode }
3
3
4
- rrule (:: typeof (copy), bc:: Broadcasted ) = copy (bc), Δ -> (NoTangent (), Δ)
4
# Rule for `copy(bc)` on a `Broadcasted`: the pullback hands the incoming cotangent
# back unchanged as the gradient for `bc`, with `NoTangent()` for `copy` itself.
function rrule(::typeof(copy), bc::Broadcasted)
    copy_back(dy) = (NoTangent(), dy)
    y = copy(bc)
    return y, copy_back
end
5
8
6
9
# Skip AD'ing through the axis computation: `instantiate` is run directly, and the
# pullback returns the cotangent unchanged for `bc` (with `NoTangent()` for the function).
function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted)
    instantiate_back(dy) = (NoTangent(), dy)
    bc_inst = Broadcast.instantiate(bc)
    return bc_inst, instantiate_back
end
11
14
12
- _print (args... ) = nothing # println(join(args, " "))
15
# Debugging hook called at the top of the rules in this file; it accepts any arguments
# and deliberately does nothing. To trace which rule fires, temporarily change the body
# to `println(join(args, " "))`.
function _print(args...)
    return nothing
end
13
16
14
17
# ####
15
18
# #### Split broadcasting
69
72
70
73
# Don't run broadcasting on scalars: when every argument is a `Number`, the call
# `f(args...)` is differentiated via `rrule_via_ad`, and the pullback prepends a
# `NoTangent()` for `broadcasted` to whatever that pullback returns.
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Number...) where {F}
    _print(" split_bc_scalar", f)
    y, f_back = rrule_via_ad(cfg, f, args...)
    scalar_back(dy) = (NoTangent(), f_back(dy)...)
    return y, scalar_back
end
77
79
78
- # using StructArrays
79
- #
80
- # function tuplecast(f::F, args...) where {F}
81
- # T = Broadcast.combine_eltypes(f, args)
82
- # if isconcretetype(T)
83
- # T <: Tuple || throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple."))
84
- # end
85
- # bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
86
- # StructArrays.components(StructArray(bc))
87
- # end
88
-
89
80
# ####
90
81
# #### Fused broadcasting
91
82
# ####
92
83
93
- # For certain cheap operations we can easily allow fused broadcast.
94
- # These all have `RuleConfig{>:HasReverseMode}` as otherwise the split rule matches first & they are not used.
95
- # They accept `Broadcasted` because they produce it; it has no eltype but is assumed to contain `Number`s.
84
+ # For certain cheap operations we can easily allow fused broadcast; the forward pass may be run twice.
85
+ # These all have `RuleConfig{>:HasReverseMode}` only for dispatch, to beat the split rule above.
86
+ # Accept `x::Broadcasted` because they produce it; can't dispatch on eltype but `x` is assumed to contain `Number`s.
87
+
96
88
# Argument types accepted by the fused rules that follow: a single number, an array of
# numbers, a tuple of numbers, or a `Broadcasted` object (whose element type cannot be
# dispatched on, so its contents are not checked here).
const NumericOrBroadcast = Union{Number, AbstractArray{<: Number }, NTuple{<: Any ,Number}, Broadcast. Broadcasted}
97
89
90
+ # #### Arithmetic: +, -, *, ^2, /
91
+
98
92
# Fused rule for `broadcasted(+, xs...)`. The forward result is the lazy `broadcasted`
# call; the pullback gives each argument the cotangent passed through `unbroadcast`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
    _print(" plus", length(xs))
    function bc_plus_back(dy_raw)
        dy = unthunk(dy_raw)
        # No copies are made here, so this may return dx2 === dx3.
        dxs = map(x -> unbroadcast(x, dy), xs)
        return (NoTangent(), NoTangent(), dxs...)
    end
    return broadcasted(+, xs...), bc_plus_back
end
106
100
107
101
# Fused rule for two-argument `broadcasted(-, x, y)`. Both gradients are thunks:
# `unbroadcast(x, dz)` for `x`, and the negation of `unbroadcast(y, dz)` for `y`.
function rrule(
    ::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast
)
    _print(" minus 2")
    function bc_minus_back(dz_raw)
        dz = unthunk(dz_raw)
        dx = @thunk(unbroadcast(x, dz))
        dy = @thunk(-unbroadcast(y, dz))
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(-, x, y), bc_minus_back
end
@@ -118,46 +113,59 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
118
113
return broadcasted (- , x), bc_minus_back
119
114
end
120
115
121
- using LinearAlgebra: dot
122
-
123
116
# Fused rule for two-argument `broadcasted(*, x, y)`. The gradient of each argument is
# computed by `_back_star`, called with that argument first and the other one second.
function rrule(
    ::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast
)
    _print(" times")
    function bc_times_back(Δraw)
        Δ = unthunk(Δraw)
        dx = _back_star(x, y, Δ)
        dy = _back_star(y, x, Δ)
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(*, x, y), bc_times_back
end
131
# Gradient of `x .* y` with respect to its first argument `x`, given the cotangent `Δ`.

# Generic case (this probably isn't better than the generic broadcasting rule).
function _back_star(x, y, Δ)
    return @thunk(unbroadcast(x, Δ .* conj.(y)))
end

# Scalar `x`: the gradient is a single `dot` product -- this is why the rule exists.
function _back_star(x::Number, y, Δ)
    return @thunk(LinearAlgebra.dot(y, Δ))
end

# Boolean factors get no gradient.
_back_star(::Bool, y, Δ) = NoTangent()
_back_star(::Complex{Bool}, y, Δ) = NoTangent()  # e.g. for fun.(im.*x)
135
128
136
- # TODO check what happens for A * B * C
129
+ #=
130
+ # This works, but not sure it improves any benchmarks.
131
+ function rrule(cfg::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...)
132
+ _print("times", 2 + length(zs))
133
+ xy, back1 = rrule(cfg, broadcasted, *, x, y)
134
+ xyz, back2 = rrule(cfg, broadcasted, *, xy, zs...)
135
+ function bc_times3_back(dxyz)
136
+ _, _, dxy, dzs... = back2(dxyz)
137
+ _, _, dx, dy = back1(dxy)
138
+ return (NoTangent(), NoTangent(), dx, dy, dzs...)
139
+ end
140
+ xyz, bc_times3_back
141
+ end
142
+ =#
137
143
138
144
# Fused rule for `x .^ 2`, which lowers to `broadcasted(Base.literal_pow, ^, x, Val(2))`.
# Only `x` receives a gradient: a thunk of `2 .* dy .* conj.(x)` projected with
# `ProjectTo(x)`; the other four slots are `NoTangent()`.
function rrule(
    ::RCR,
    ::typeof(broadcasted),
    ::typeof(Base.literal_pow),
    ::typeof(^),
    x::NumericOrBroadcast,
    ::Val{2},
)
    _print(" square")
    function bc_square_back(dy_raw)
        dx = @thunk(ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x)))
        return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
    end
    y = broadcasted(Base.literal_pow, ^, x, Val(2))
    return y, bc_square_back
end
146
152
147
153
# Fused rule for `broadcasted(/, x, y)` with a scalar denominator `y`.
# The forward result `z` is kept lazy. In the pullback, `dx` is `dz ./ conj.(y)` passed
# through `unbroadcast`, and `dy` is minus the complete sum of `conj.(z) .* dz`, divided
# by `conj(y)`; that sum is taken over a lazy `Broadcasted` rather than a new array.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
    _print(" divide")
    z = broadcasted(/, x, y)
    function bc_divide_back(dz_raw)
        dz = unthunk(dz_raw)
        dx = @thunk(unbroadcast(x, dz ./ conj.(y)))
        dy = @thunk begin
            zdz = Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))
            -sum(zdz) / conj(y)  # complete sum is fast?
        end
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return z, bc_divide_back
end
158
166
159
167
# For the same functions, send accidental broadcasting over numbers directly to `rrule`.
160
- # Could perhaps move all to @scalar_rule?
168
+ # ( Could perhaps move all to @scalar_rule?)
161
169
162
170
function _prepend_zero ((y, back))
163
171
extra_back (dy) = (NoTangent (), back (dy)... )
@@ -172,33 +180,74 @@ rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::
172
180
rrule (Base. literal_pow, ^ , x, Val (2 )) |> _prepend_zero
173
181
# `/` broadcast over two plain numbers: defer to the ordinary `rrule(/, x, y)` and pass
# its result through `_prepend_zero`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number)
    return _prepend_zero(rrule(/, x, y))
end
174
182
175
- # A few more cheap functions
183
+ # #### Identity, number types
176
184
177
185
# `identity.(x)`: reuse the ordinary `rrule(identity, x)`, passed through `_prepend_zero`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast)
    return _prepend_zero(rrule(identity, x))
end
# Same rule restricted to `Number`; this method exists to resolve a dispatch ambiguity.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::Number)
    return _prepend_zero(rrule(identity, x))
end
179
187
180
- function rrule (:: RCR , :: typeof (broadcasted), :: typeof (conj), x:: NumericOrBroadcast )
181
- bc_conj_back (dx) = (NoTangent (), NoTangent (), conj (unthunk (dx)))
182
- return broadcasted (conj, x), bc_conj_back
188
# Broadcasting a number type as a constructor, `T.(x)` with `T <: Number`.
# The forward pass stays lazy; the gradient for `x` is a thunk of the cotangent passed
# through `unbroadcast`.
function rrule(
    ::RCR, ::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast
) where {T<:Number}
    _print(" bc type", T)
    function bc_type_back(dz)
        dx = @thunk(unbroadcast(x, unthunk(dz)))
        return (NoTangent(), NoTangent(), dx)
    end
    return broadcasted(T, x), bc_type_back
end
# For a single `Number`, defer to the ordinary `rrule(T, x)` via `_prepend_zero`.
function rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number}
    return _prepend_zero(rrule(T, x))
end
194
+
195
# Fused rule for `float.(x)`. The forward pass stays lazy; the gradient for `x` is a
# thunk of the cotangent passed through `unbroadcast`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast)
    _print(" bc float")
    function bc_float_back(dz)
        dx = @thunk(unbroadcast(x, unthunk(dz)))
        return (NoTangent(), NoTangent(), dx)
    end
    return broadcasted(float, x), bc_float_back
end
184
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (conj), x:: Number ) = rrule (conj, x) |> _prepend_zero
185
- rrule (:: RCR , :: typeof (broadcasted), :: typeof (conj), x:: AbstractArray{<:Real} ) = rrule (identity, x) |> _prepend_zero
200
# `float` broadcast over a single number: defer to the ordinary `rrule(float, x)` and
# pass its result through `_prepend_zero`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::Number)
    return _prepend_zero(rrule(float, x))
end
186
201
187
- # TODO real, imag
202
+ # #### Complex: conj, real, imag
203
+
204
# Generate the same three rules for `conj` and `adjoint`; the two are identical here
# because the element type is known to be `<: Number`.
# `fname` is the function being broadcast. The `conj` inside the pullback is always the
# ordinary `conj` function, for both generated versions.
for fname in (:conj, :adjoint)
    @eval begin
        function rrule(
            ::RCR, ::typeof(broadcasted), ::typeof($fname), x::NumericOrBroadcast
        )
            bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx)))
            return broadcasted($fname, x), bc_conj_back
        end
        function rrule(::RCR, ::typeof(broadcasted), ::typeof($fname), x::Number)
            return _prepend_zero(rrule($fname, x))
        end
        function rrule(
            ::RCR, ::typeof(broadcasted), ::typeof($fname), x::AbstractArray{<:Real}
        )
            return _prepend_zero(rrule(identity, x))
        end
        # This `AbstractArray{<:Real}` rule won't catch `conj.(x.+1)` with lazy `.+` rule.
        # Could upgrade to infer eltype of the `Broadcasted`?
    end
end
216
+
217
# Fused rule for `real.(x)`: the forward pass stays lazy, and the gradient for `x` is
# a thunk of `real` applied to the cotangent.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast)
    _print(" real")
    function bc_real_back(dz)
        dx = @thunk(real(unthunk(dz)))
        return (NoTangent(), NoTangent(), dx)
    end
    return broadcasted(real, x), bc_real_back
end
# A single number: defer to the ordinary `rrule(real, x)`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::Number)
    return _prepend_zero(rrule(real, x))
end
# An array of reals: handled with the `identity` rule.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real})
    return _prepend_zero(rrule(identity, x))
end
224
+
225
# Fused rule for `imag.(x)`: the forward pass stays lazy, and the gradient for `x` is
# a thunk of `im .* real.(dz)`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast)
    _print(" imag")
    function bc_imag_back(dz)
        dx = @thunk(im .* real.(unthunk(dz)))
        return (NoTangent(), NoTangent(), dx)
    end
    return broadcasted(imag, x), bc_imag_back
end
# A single number: defer to the ordinary `rrule(imag, x)`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::Number)
    return _prepend_zero(rrule(imag, x))
end
# An array of reals: the gradient for `x` is always `ZeroTangent()`.
function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real})
    _print(" imag(real)")
    bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent())
    y = broadcasted(imag, x)
    return y, bc_imag_back_2
end
188
236
189
237
# ####
190
238
# #### Shape fixing
191
239
# ####
192
240
193
- # Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape:
241
+ # When sizes disagree, broadcasting gradient uses `unbroadcast` to reduce to correct shape.
242
+ # It's sometimes a little wasteful to allocate a too-large `dx`, but difficult to make more efficient.
194
243
195
244
# Reduce the gradient `dx` of a broadcast result to one matching the argument `x`.
# When the lengths agree, `dx` is only passed through `ProjectTo(x)`. Otherwise `dx` is
# first summed over every dimension in which `x` has size 1 (or no such dimension).
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
    N = ndims(dx)
    if length(x) == length(dx)
        # handles trivial reshapes, offsets, structured matrices, row vectors
        return ProjectTo(x)(dx)
    end
    # Hack to get type-stable `dims`: always an `N`-tuple, where dimensions that should
    # not be summed are replaced by `N + 1`, which is beyond the dimensions of `dx`.
    dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N + 1, N)
    return ProjectTo(x)(sum(dx; dims))
end
# A zero gradient needs no reduction and is returned as it is.
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx
0 commit comments