@@ -13,7 +13,6 @@ function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted)
13
13
return Broadcast. instantiate (bc), uninstantiate
14
14
end
15
15
16
- _print (args... ) = printstyled (" CR: " , join (args, " " ), " \n " , color= :magenta ) # nothing #
17
16
18
17
# ####
19
18
# #### Split broadcasting
@@ -31,7 +30,7 @@ function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Va
31
30
T = Broadcast. combine_eltypes (f, args)
32
31
if T === Bool # TODO use nondifftype here
33
32
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
34
- _print ( " split_bc_trivial " , f)
33
+ @debug ( " split broadcasting trivial " , f, T )
35
34
bc_trivial_back (_) = (TRI_NO... , ntuple (Returns (ZeroTangent ()), length (args))... )
36
35
return f .(args... ), bc_trivial_back
37
36
elseif T <: Number && may_bc_derivatives (T, f, args... )
@@ -57,7 +56,7 @@ _eltype(x) = eltype(x) # ... but try harder to avoid `eltype(Broadcast.broadcas
57
56
# Element type of a lazy `Broadcasted` object, computed from its function and
# arguments via `Broadcast.combine_eltypes` rather than by calling `eltype` on it.
_eltype(bc::Broadcast.Broadcasted) = Broadcast.combine_eltypes(bc.f, bc.args)
58
57
59
58
function split_bc_derivatives (f:: F , arg) where {F}
60
- _print ( " split_bc_derivative " , f)
59
+ @debug ( " split broadcasting derivative " , f)
61
60
ys = f .(arg)
62
61
function bc_one_back (dys) # For f.(x) we do not need StructArrays / unzip at all
63
62
delta = broadcast (unthunk (dys), ys, arg) do dy, y, a
@@ -70,7 +69,7 @@ function split_bc_derivatives(f::F, arg) where {F}
70
69
return ys, bc_one_back
71
70
end
72
71
function split_bc_derivatives (f:: F , args:: Vararg{Any,N} ) where {F,N}
73
- _print ( " split_bc_derivatives " , f, N)
72
+ @debug ( " split broadcasting derivatives " , f, N)
74
73
ys = f .(args... )
75
74
function bc_many_back (dys)
76
75
deltas = tuplecast (unthunk (dys), ys, args... ) do dy, y, as...
105
104
# Select the forward-mode function handed to `split_bc_inner`:
# `frule_via_ad` when the config advertises `HasForwardsMode`, plain `frule` otherwise.
split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = split_bc_inner(frule_via_ad, cfg, f, arg)
split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg)
107
106
function split_bc_inner (frule_fun:: R , cfg:: RuleConfig , f:: F , arg) where {R,F}
108
- _print ( " split_bc_forwards " , frule_fun, f)
107
+ @debug ( " split broadcasting forwards " , frule_fun, f)
109
108
ys, ydots = tuplecast (arg) do a
110
109
frule_fun (cfg, (NoTangent (), one (a)), f, a)
111
110
end
124
123
# can change the number of calls, don't bother to try to reverse the iteration.
125
124
126
125
function split_bc_pullbacks (cfg:: RCR , f:: F , args:: Vararg{Any,N} ) where {F,N}
127
- _print ( " split_bc_generic " , f, N)
126
+ @debug ( " split broadcasting generic " , f, N)
128
127
ys3, backs = tuplecast (args... ) do a...
129
128
rrule_via_ad (cfg, f, a... )
130
129
end
142
141
143
142
# Don't run broadcasting on scalars
144
143
function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Number...) where {F}
    # Every argument is a `Number`, so differentiate the single call `f(args...)`
    # through `rrule_via_ad` instead of going through the array broadcasting paths.
    @debug("split broadcasting scalar", f)
    z, back = rrule_via_ad(cfg, f, args...)
    # Two leading `NoTangent()`s precede the gradients returned by `back`
    # (which already start with the tangent for `f` itself).
    return z, dz -> (NoTangent(), NoTangent(), back(dz)...)
end
@@ -159,7 +158,7 @@ const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,N
159
158
# #### Arithmetic: +, -, *, ^2, /
160
159
161
160
function rrule (:: typeof (broadcasted), :: typeof (+ ), xs:: NumericOrBroadcast... )
162
- _print ( " plus" , length (xs))
161
+ @debug ( " broadcasting: plus" , length (xs))
163
162
function bc_plus_back (dy_raw)
164
163
dy = unthunk (dy_raw)
165
164
return (NoTangent (), NoTangent (), map (x -> unbroadcast (x, dy), xs)... ) # no copies, this may return dx2 === dx3
@@ -168,7 +167,7 @@ function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
168
167
end
169
168
170
169
function rrule (:: typeof (broadcasted), :: typeof (- ), x:: NumericOrBroadcast , y:: NumericOrBroadcast )
171
- _print ( " minus 2" )
170
+ @debug ( " broadcasting: minus 2" )
172
171
function bc_minus_back (dz_raw)
173
172
dz = unthunk (dz_raw)
174
173
return (NoTangent (), NoTangent (), @thunk (unbroadcast (x, dz)), @thunk (- unbroadcast (y, dz)))
@@ -177,13 +176,13 @@ function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::Num
177
176
end
178
177
179
178
function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
    # Unary minus: the primal is `broadcasted(-, x)` and the pullback negates
    # the incoming cotangent, deferred inside a thunk.
    @debug("broadcasting: minus 1")
    bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk(-unthunk(dy)))
    return broadcasted(-, x), bc_minus_back
end
184
183
185
184
function rrule (:: typeof (broadcasted), :: typeof (* ), x:: NumericOrBroadcast , y:: NumericOrBroadcast )
186
- _print ( " times" )
185
+ @debug ( " broadcasting: times" )
187
186
function bc_times_back (Δraw)
188
187
Δ = unthunk (Δraw)
189
188
return (NoTangent (), NoTangent (), _back_star (x, y, Δ), _back_star (y, x, Δ))
@@ -197,7 +196,7 @@ _back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x)
197
196
198
197
# This works, but not sure it improves any benchmarks. Needs corresponding scalar rule to avoid ambiguities.
199
198
function rrule (:: typeof (broadcasted), :: typeof (* ), x:: NumericOrBroadcast , y:: NumericOrBroadcast , zs:: NumericOrBroadcast... )
200
- _print ( " times" , 2 + length (zs))
199
+ @debug ( " broadcasting: times" , 2 + length (zs))
201
200
xy, back1 = rrule (broadcasted, * , x, y)
202
201
xyz, back2 = rrule (broadcasted, * , xy, zs... )
203
202
function bc_times3_back (dxyz)
@@ -209,7 +208,7 @@ function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::Num
209
208
end
210
209
211
210
function rrule (:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: NumericOrBroadcast , :: Val{2} )
212
- _print ( " square" )
211
+ @debug ( " broadcasting: square" )
213
212
function bc_square_back (dy_raw)
214
213
dx = @thunk ProjectTo (x)(2 .* unthunk (dy_raw) .* conj .(x))
215
214
return (NoTangent (), NoTangent (), NoTangent (), dx, NoTangent ())
@@ -218,7 +217,7 @@ function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x
218
217
end
219
218
220
219
function rrule (:: typeof (broadcasted), :: typeof (/ ), x:: NumericOrBroadcast , y:: Number )
221
- _print ( " divide" )
220
+ @debug ( " broadcasting: divide" )
222
221
# z = broadcast(/, x, y)
223
222
z = broadcasted (/ , x, y)
224
223
function bc_divide_back (dz_raw)
@@ -255,14 +254,14 @@ rrule(::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(
255
254
# Scalar `identity.(x)`: reuse the plain `rrule(identity, x)` and pass it through `_prepend_zero`.
# Marked "ambiguity" in the original -- presumably this method exists to resolve a dispatch ambiguity.
rrule(::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero  # ambiguity
256
255
257
256
function rrule(::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number}
    # Broadcasting a numeric type constructor `T.(x)`: the pullback passes the
    # cotangent to `unbroadcast(x, ...)`, deferred inside a thunk.
    @debug("broadcasting: type", T)
    bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
    return broadcasted(T, x), bc_type_back
end
262
261
# Scalar `T.(x)`: reuse the plain `rrule(T, x)` and pass it through `_prepend_zero`.
rrule(::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero
263
262
264
263
function rrule(::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast)
    # `float.(x)`: the pullback passes the cotangent to `unbroadcast(x, ...)`,
    # deferred inside a thunk.
    @debug("broadcasting: float")
    bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
    return broadcasted(float, x), bc_float_back
end
@@ -284,27 +283,27 @@ for conj in [:conj, :adjoint] # identical as we know eltype <: Number
284
283
end
285
284
286
285
function rrule(::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast)
    # `real.(x)`: the pullback returns `real` of the incoming cotangent,
    # deferred inside a thunk.
    @debug("broadcasting: real")
    bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz))))
    return broadcasted(real, x), bc_real_back
end
291
290
# Scalar `real.(x)`: reuse the plain `rrule(real, x)` and pass it through `_prepend_zero`.
rrule(::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero
# `real.(x)` on an array with `Real` eltype is handled with the rule for `identity`.
rrule(::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
293
292
294
293
function rrule(::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast)
    # `imag.(x)`: the pullback returns `im .* real.(dz)`, deferred inside a thunk.
    @debug("broadcasting: imag")
    bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz))))
    return broadcasted(imag, x), bc_imag_back
end
299
298
# Scalar `imag.(x)`: reuse the plain `rrule(imag, x)` and pass it through `_prepend_zero`.
rrule(::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero
300
299
function rrule(::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real})
    # `imag.(x)` on an array with `Real` eltype: the pullback ignores the
    # cotangent and always returns `ZeroTangent()` for `x`.
    @debug("broadcasting: imag(real)")
    bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent())
    return broadcasted(imag, x), bc_imag_back_2
end
305
304
306
305
function rrule(::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast)
    # `complex.(x)`: the pullback passes the cotangent to `unbroadcast(x, ...)`,
    # deferred inside a thunk.
    @debug("broadcasting: complex")
    bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
    return broadcasted(complex, x), bc_complex_back
end
0 commit comments