Skip to content

Commit b6d279b

Browse files
committed
debug
1 parent 7e7d105 commit b6d279b

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

src/rulesets/Base/broadcast.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted)
1313
return Broadcast.instantiate(bc), uninstantiate
1414
end
1515

16-
_print(args...) = printstyled("CR: ", join(args, " "), "\n", color=:magenta) # nothing #
1716

1817
#####
1918
##### Split broadcasting
@@ -31,7 +30,7 @@ function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Va
3130
T = Broadcast.combine_eltypes(f, args)
3231
if T === Bool # TODO use nondifftype here
3332
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
34-
_print("split_bc_trivial", f)
33+
@debug("split broadcasting trivial", f, T)
3534
bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)
3635
return f.(args...), bc_trivial_back
3736
elseif T <: Number && may_bc_derivatives(T, f, args...)
@@ -57,7 +56,7 @@ _eltype(x) = eltype(x) # ... but try harder to avoid `eltype(Broadcast.broadcas
5756
_eltype(bc::Broadcast.Broadcasted) = Broadcast.combine_eltypes(bc.f, bc.args)
5857

5958
function split_bc_derivatives(f::F, arg) where {F}
60-
_print("split_bc_derivative", f)
59+
@debug("split broadcasting derivative", f)
6160
ys = f.(arg)
6261
function bc_one_back(dys) # For f.(x) we do not need StructArrays / unzip at all
6362
delta = broadcast(unthunk(dys), ys, arg) do dy, y, a
@@ -70,7 +69,7 @@ function split_bc_derivatives(f::F, arg) where {F}
7069
return ys, bc_one_back
7170
end
7271
function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N}
73-
_print("split_bc_derivatives", f, N)
72+
@debug("split broadcasting derivatives", f, N)
7473
ys = f.(args...)
7574
function bc_many_back(dys)
7675
deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as...
@@ -105,7 +104,7 @@ end
105104
split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = split_bc_inner(frule_via_ad, cfg, f, arg)
106105
split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg)
107106
function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
108-
_print("split_bc_forwards", frule_fun, f)
107+
@debug("split broadcasting forwards", frule_fun, f)
109108
ys, ydots = tuplecast(arg) do a
110109
frule_fun(cfg, (NoTangent(), one(a)), f, a)
111110
end
@@ -124,7 +123,7 @@ end
124123
# can change the number of calls, don't bother to try to reverse the iteration.
125124

126125
function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
127-
_print("split_bc_generic", f, N)
126+
@debug("split broadcasting generic", f, N)
128127
ys3, backs = tuplecast(args...) do a...
129128
rrule_via_ad(cfg, f, a...)
130129
end
@@ -142,7 +141,7 @@ end
142141

143142
# Don't run broadcasting on scalars
144143
function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Number...) where {F}
145-
_print("split_bc_scalar", f)
144+
@debug("split broadcasting scalar", f)
146145
z, back = rrule_via_ad(cfg, f, args...)
147146
return z, dz -> (NoTangent(), NoTangent(), back(dz)...)
148147
end
@@ -159,7 +158,7 @@ const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,N
159158
##### Arithmetic: +, -, *, ^2, /
160159

161160
function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
162-
_print("plus", length(xs))
161+
@debug("broadcasting: plus", length(xs))
163162
function bc_plus_back(dy_raw)
164163
dy = unthunk(dy_raw)
165164
return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) # no copies, this may return dx2 === dx3
@@ -168,7 +167,7 @@ function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...)
168167
end
169168

170169
function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
171-
_print("minus 2")
170+
@debug("broadcasting: minus 2")
172171
function bc_minus_back(dz_raw)
173172
dz = unthunk(dz_raw)
174173
return (NoTangent(), NoTangent(), @thunk(unbroadcast(x, dz)), @thunk(-unbroadcast(y, dz)))
@@ -177,13 +176,13 @@ function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::Num
177176
end
178177

179178
function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast)
180-
_print("minus 1")
179+
@debug("broadcasting: minus 1")
181180
bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk -unthunk(dy))
182181
return broadcasted(-, x), bc_minus_back
183182
end
184183

185184
function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
186-
_print("times")
185+
@debug("broadcasting: times")
187186
function bc_times_back(Δraw)
188187
Δ = unthunk(Δraw)
189188
return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ))
@@ -197,7 +196,7 @@ _back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x)
197196

198197
# This works, but not sure it improves any benchmarks. Needs corresponding scalar rule to avoid ambiguities.
199198
function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...)
200-
_print("times", 2 + length(zs))
199+
@debug("broadcasting: times", 2 + length(zs))
201200
xy, back1 = rrule(broadcasted, *, x, y)
202201
xyz, back2 = rrule(broadcasted, *, xy, zs...)
203202
function bc_times3_back(dxyz)
@@ -209,7 +208,7 @@ function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::Num
209208
end
210209

211210
function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
212-
_print("square")
211+
@debug("broadcasting: square")
213212
function bc_square_back(dy_raw)
214213
dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x))
215214
return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
@@ -218,7 +217,7 @@ function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x
218217
end
219218

220219
function rrule(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
221-
_print("divide")
220+
@debug("broadcasting: divide")
222221
# z = broadcast(/, x, y)
223222
z = broadcasted(/, x, y)
224223
function bc_divide_back(dz_raw)
@@ -255,14 +254,14 @@ rrule(::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(
255254
rrule(::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity
256255

257256
function rrule(::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number}
258-
_print("bc type", T)
257+
@debug("broadcasting: type", T)
259258
bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
260259
return broadcasted(T, x), bc_type_back
261260
end
262261
rrule(::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero
263262

264263
function rrule(::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast)
265-
_print("bc float")
264+
@debug("broadcasting: float")
266265
bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
267266
return broadcasted(float, x), bc_float_back
268267
end
@@ -284,27 +283,27 @@ for conj in [:conj, :adjoint] # identical as we know eltype <: Number
284283
end
285284

286285
function rrule(::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast)
287-
_print("real")
286+
@debug("broadcasting: real")
288287
bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz))))
289288
return broadcasted(real, x), bc_real_back
290289
end
291290
rrule(::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero
292291
rrule(::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero
293292

294293
function rrule(::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast)
295-
_print("imag")
294+
@debug("broadcasting: imag")
296295
bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz))))
297296
return broadcasted(imag, x), bc_imag_back
298297
end
299298
rrule(::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero
300299
function rrule(::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real})
301-
_print("imag(real)")
300+
@debug("broadcasting: imag(real)")
302301
bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent())
303302
return broadcasted(imag, x), bc_imag_back_2
304303
end
305304

306305
function rrule(::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast)
307-
_print("bc complex")
306+
@debug("broadcasting: complex")
308307
bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz))))
309308
return broadcasted(complex, x), bc_complex_back
310309
end

0 commit comments

Comments
 (0)