Skip to content

Commit b484a14

Browse files
committed
lazier +,-,* rules
1 parent 745e4ee commit b484a14

File tree

1 file changed

+66
-8
lines changed

1 file changed

+66
-8
lines changed

src/stage1/broadcast.jl

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ function split_bc_rule(f::F, args...) where {F}
4444
# Trivial case
4545
back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
4646
return f.(args...), back_1
47-
elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type(
47+
# elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type(
48+
elseif isconcretetype(Core.Compiler._return_type(
4849
derivatives_given_output, Tuple{T, F, map(eltype, args)...}))
4950
# Fast path: just broadcast, and use x & y to find derivative.
5051
ys = f.(args...)
@@ -79,6 +80,7 @@ splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiat
7980

8081
unbroadcast(f::Function, x̄) = accum_sum(x̄)
8182
unbroadcast(::Val, _) = NoTangent()
83+
unbroadcast(x::AbstractArray, x̄::NoTangent) = NoTangent()
8284
accum_sum(xs::AbstractArray{<:NoTangent}; dims = :) = NoTangent()
8385

8486
#=
@@ -99,6 +101,7 @@ julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs);
99101
44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure?
100102
61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse
101103
104+
102105
# Composed function, Zygote struggles
103106
104107
julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
@@ -116,6 +119,27 @@ julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs);
116119
75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies
117120
135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse
118121
122+
123+
# Lazy +,-,* for partial fusing
124+
125+
julia> @btime Zygote.gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
126+
81.250 μs (21 allocations: 625.47 KiB) # special rules + dual numbers, 4 more copies than best
127+
128+
julia> @btime gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
129+
57.166 μs (49 allocations: 470.22 KiB) # broadcast in *, -
130+
54.583 μs (46 allocations: 314.06 KiB) # broadcasted -- two less copies
131+
72.958 μs (26 allocations: 1016.38 KiB) # with `map` rule as before
132+
133+
julia> gradient((x,y) -> sum(abs2, exp.(2 .* x .+ y)), xs, (rand(10)'))
134+
ERROR: MethodError: no method matching size(::Base.Broadcast.Broadcasted # hmm
135+
136+
julia> @btime gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
137+
7.127 ms (75 allocations: 22.97 MiB) # after
138+
12.956 ms (57 allocations: 76.37 MiB) # before
139+
140+
ulia> @btime Zygote.gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
141+
9.937 ms (48 allocations: 45.86 MiB)
142+
119143
=#
120144

121145
# The below is from Zygote: TODO: DO we want to do something better here?
@@ -133,7 +157,7 @@ end
133157

134158
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
135159

136-
unbroadcast(x::AbstractArray, x̄) =
160+
unbroadcast(x::Union{AbstractArray, Base.Broadcast.Broadcasted}, x̄) =
137161
size(x) == size(x̄) ?:
138162
length(x) == length(x̄) ? trim(x, x̄) :
139163
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
@@ -146,12 +170,46 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
146170

147171
const Numeric = Union{Number, AbstractArray{<:Number, N} where N}
148172

149-
function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
150-
broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
173+
# function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
174+
# broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
175+
# end
176+
177+
# Replace Zygote-like fully split broadcasting with one fused over easy operations
178+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...)
179+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = split_bc_plus(arg) # ambiguity
180+
function split_bc_plus(xs...) where {F}
181+
broadcasted(+, xs...), Δ -> let Δun = unthunk(Δ)
182+
# println("+")
183+
(NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...)
184+
end
151185
end
186+
Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
187+
mapreduce(eltype, promote_type, bc.args) # needed to hit fast path
152188

153-
ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
154-
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
189+
(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)
155190

156-
ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
157-
-> let=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
191+
# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
192+
# Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
193+
194+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
195+
broadcasted(-, x, y), Δ -> let Δun = unthunk(Δ)
196+
# println("-")
197+
(NoTangent(), NoTangent(), unbroadcast(x, Δun), -unbroadcast(y, Δun))
198+
# Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
199+
end
200+
end
201+
202+
# ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
203+
# z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
204+
205+
using LinearAlgebra: dot
206+
207+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
208+
broadcasted(*, x, y), Δ -> let Δun = unthunk(Δ)
209+
# println("*")
210+
dx = x isa Number ? dot(y, Δun) : unbroadcast(x, Δun .* conj.(y))
211+
dy = y isa Number ? dot(x, Δun) : unbroadcast(y, Δun .* conj.(x))
212+
# When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
213+
(NoTangent(), NoTangent(), dx, dy)
214+
end
215+
end

0 commit comments

Comments
 (0)