Skip to content

Commit 569ae32

Browse files
committed
tidy, add more lazy cases
1 parent 7eb9be8 commit 569ae32

File tree

2 files changed

+79
-22
lines changed

2 files changed

+79
-22
lines changed

src/stage1/broadcast.jl

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,30 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
2929
return r
3030
end
3131

32-
_print(s) = nothing
33-
# _print(s) = printstyled(s, "\n"; color=:magenta)
32+
# Reverse mode broadcast rules
33+
34+
using ChainRulesCore: derivatives_given_output
35+
36+
# _print(s) = nothing
37+
_print(s) = printstyled(s, "\n"; color=:magenta)
3438

3539
# Broadcast over one element is just map
3640
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
3741
_print("path 0")
3842
∂⃖ₙ(map, f, a)
3943
end
4044

41-
using ChainRulesCore: derivatives_given_output
42-
4345
(::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...)
4446
(::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity
4547
function split_bc_rule(f::F, args...) where {F}
4648
T = Broadcast.combine_eltypes(f, args)
47-
if T == Bool
48-
# Trivial case
49+
= Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
50+
if eltype(T) == Bool
51+
# Trivial case: non-differentiable output
4952
_print("path 1")
5053
back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
5154
return f.(args...), back_1
52-
elseif isconcretetype(Core.Compiler._return_type(
53-
derivatives_given_output, Tuple{T, F, map(eltype, args)...}))
55+
elseif T <: Number && isconcretetype(TΔ)
5456
# Fast path: just broadcast, and use x & y to find derivative.
5557
ys = f.(args...)
5658
_print("path 2")
@@ -65,8 +67,9 @@ function split_bc_rule(f::F, args...) where {F}
6567
return ys, back_2
6668
else
6769
# Slow path: collect all the pullbacks & apply them later.
70+
# Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
6871
_print("path 3")
69-
ys, backs = splitcast(rrule_via_ad, DiffractorRuleConfig(), f, args...)
72+
ys, backs = splitcast(∂⃖{1}(), f, args...)
7073
function back_3(dys)
7174
deltas = splitmap(backs, unthunk(dys)) do back, dy
7275
map(unthunk, back(dy))
@@ -78,8 +81,11 @@ function split_bc_rule(f::F, args...) where {F}
7881
end
7982
end
8083

84+
# This uses "mulltimap"-like constructs:
85+
8186
using StructArrays
82-
splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
87+
splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...)))
88+
# warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
8389
splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
8490

8591
# For certain cheap operations we can easily allow fused broadcast:
@@ -107,7 +113,7 @@ end
107113

108114
using LinearAlgebra: dot
109115

110-
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
116+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y) # should this be vararg, or will laziness handle it?
111117
broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ)
112118
_print("broadcast *")
113119
dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y))
@@ -117,41 +123,88 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
117123
(NoTangent(), NoTangent(), dx, dy)
118124
end
119125
end
126+
# Alternative to `x isa Number` etc above... but not quite right!
127+
# (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y::Number) = rrule_via_ad(DiffractorRuleConfig(), *, x, y)
128+
129+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2})
130+
_print("broadcast ^2")
131+
broadcasted(*, x, x), Δ -> begin
132+
dx = unbroadcast(x, 2 .* Δ .* conj.(x))
133+
(NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
134+
end
135+
end
136+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2})
137+
_print("simple ^2")
138+
x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent())
139+
end
140+
141+
# function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y) # not obvious whether this is better than automatic
142+
# broadcasted(/, x, y), Δ -> let Δun = unthunk(Δ)
143+
# _print("broadcast /")
144+
# dx = unbroadcast(x, Δ ./ conj.(y))
145+
# dy = unbroadcast(y, .-Δ .* conj.(res ./ y))
146+
# (NoTangent(), NoTangent(), dx, dy)
147+
# end
148+
# end
149+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y::Number)
150+
_print("simple /")
151+
z, back = ∂⃖{1}()(/, x, y)
152+
z, Δ -> begin
153+
_, dx, dy = back(Δ)
154+
(NoTangent(), NoTangent(), dx, dy) # maybe there should be a funciton for this? Use for conj, identity too
155+
end
156+
end
120157

121158
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) =
122159
broadcasted(conj, x), Δ -> (NoTangent(), conj(unthunk(Δ)))
123160
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) =
124161
x, Δ -> (NoTangent(), Δ)
125162

163+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) =
164+
x, Δ -> (NoTangent(), Δ)
165+
166+
# All broadcasts use `unbroadcast` to reduce to correct shape:
167+
126168
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
127169
N = ndims(dx)
128170
if length(x) == length(dx)
129171
ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
130172
else
131-
# This is an awful hack to get type-stable `dims`
132-
dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N)
173+
dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # awful hack to get type-stable `dims`
133174
ProjectTo(x)(sum(dx; dims))
134175
end
135176
end
136177
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::NoTangent) = NoTangent()
137178

179+
unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx)))
180+
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
181+
_print("unbroadcast tuple")
182+
val = if length(x) == length(dx)
183+
dx
184+
else
185+
sum(dx; dims=2:ndims(dx))
186+
end
187+
ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
188+
end
189+
190+
unbroadcast(f::Function, df) = sum(df)
138191
unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx))
139-
unbroadcast(f::Function, df) = ProjectTo(x)(sum(df))
140192
unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx)))
141193

142194
unbroadcast(::Bool, dx) = NoTangent()
143195
unbroadcast(::AbstractArray{Bool}, dx) = NoTangent()
144196
unbroadcast(::AbstractArray{Bool}, ::NoTangent) = NoTangent() # ambiguity
145197
unbroadcast(::Val, dx) = NoTangent()
146-
# Maybe more non-diff types? Some fallback?
147198

148-
unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx)))
149-
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
150-
_print("unbroadcast tuple")
151-
val = if length(x) == length(dx)
152-
dx
199+
function unbroadcast(x, dx)
200+
p = ProjectTo(x)
201+
if dx isa AbstractZero || p isa ProjectTo{<:AbstractZero}
202+
return NoTangent()
203+
end
204+
b = Broadcast.broadcastable(x)
205+
if b isa Ref # then x is scalar under broadcast
206+
return p(sum(dx))
153207
else
154-
sum(dx; dims=2:ndims(dx))
208+
error("don't know how to handle broadcast gradient for x::$(typeof(x))")
155209
end
156-
ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
157210
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
216216

217217
# Broadcasting
218218
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
219+
@test gradient(x -> sum(sqrt.(atan.(x, x'))), [1,2,3])[1] [0.2338, -0.0177, -0.0661] atol=1e-3
219220
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
220221

221222
@test_broken gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # stores pullback
@@ -229,6 +230,9 @@ exp_log(x) = exp(log(x))
229230
@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3
230231
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3
231232

233+
@test unthunk.(gradient(x -> sum(x ./ 4), [1,2,3])) == ([0.25, 0.25, 0.25],)
234+
@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,)
235+
232236
@test gradient(x -> sum(x .> 2), [1,2,3]) == (ZeroTangent(),) # Bool output
233237
@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) == (ZeroTangent(),)
234238
@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (ZeroTangent(), ZeroTangent())

0 commit comments

Comments
 (0)