@@ -44,7 +44,8 @@ function split_bc_rule(f::F, args...) where {F}
44
44
# Trivial case
45
45
back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
46
46
return f .(args... ), back_1
47
- elseif all (a -> a isa Numeric, args) && isconcretetype (Core. Compiler. _return_type (
47
+ # elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type(
48
+ elseif isconcretetype (Core. Compiler. _return_type (
48
49
derivatives_given_output, Tuple{T, F, map (eltype, args)... }))
49
50
# Fast path: just broadcast, and use x & y to find derivative.
50
51
ys = f .(args... )
@@ -79,6 +80,7 @@ splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiat
79
80
80
81
# `unbroadcast(x, x̄)` maps the cotangent `x̄` of a broadcast result back onto the
# broadcast argument `x`. These methods cover the cases that need no reshaping.

# Function-valued argument: hand the whole cotangent to `accum_sum`.
unbroadcast(f::Function, x̄) = accum_sum(x̄)

# `Val` arguments receive no gradient.
unbroadcast(::Val, _) = NoTangent()

# A `NoTangent` cotangent stays `NoTangent`. Lazy `Broadcasted` arguments are included
# so that they do not fall through to the shape-comparing method, which would call
# `size` on the `NoTangent`.
function unbroadcast(x::Union{AbstractArray, Base.Broadcast.Broadcasted}, x̄::NoTangent)
    return NoTangent()
end

# Summing an array whose elements are all `NoTangent` gives `NoTangent` for any `dims`.
accum_sum(xs::AbstractArray{<:NoTangent}; dims = :) = NoTangent()
83
85
84
86
#=
@@ -99,6 +101,7 @@ julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs);
99
101
44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure?
100
102
61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse
101
103
104
+
102
105
# Composed function, Zygote struggles
103
106
104
107
julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
@@ -116,6 +119,27 @@ julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs);
116
119
75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies
117
120
135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse
118
121
122
+
123
+ # Lazy +,-,* for partial fusing
124
+
125
+ julia> @btime Zygote.gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
126
+ 81.250 μs (21 allocations: 625.47 KiB) # special rules + dual numbers, 4 more copies than best
127
+
128
+ julia> @btime gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
129
+ 57.166 μs (49 allocations: 470.22 KiB) # broadcast in *, -
130
+ 54.583 μs (46 allocations: 314.06 KiB) # broadcasted -- two fewer copies
131
+ 72.958 μs (26 allocations: 1016.38 KiB) # with `map` rule as before
132
+
133
+ julia> gradient((x,y) -> sum(abs2, exp.(2 .* x .+ y)), xs, (rand(10)'))
134
+ ERROR: MethodError: no method matching size(::Base.Broadcast.Broadcasted # hmm
135
+
136
+ julia> @btime gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
137
+ 7.127 ms (75 allocations: 22.97 MiB) # after
138
+ 12.956 ms (57 allocations: 76.37 MiB) # before
139
+
140
+ julia> @btime Zygote.gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
141
+ 9.937 ms (48 allocations: 45.86 MiB)
142
+
119
143
=#
120
144
121
145
# The below is from Zygote. TODO: Do we want to do something better here?
133
157
134
158
# Reshape `Δ` so that it has `ndims(x)` dimensions, keeping `Δ`'s own extent along each
# of those dimensions.
function trim(x, Δ)
    kept = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, kept)
end
135
159
136
# Reduce the cotangent `x̄` of a broadcast result back to the shape of the argument `x`.
#
# `x` may be an array or a lazy `Broadcasted`. Its extents are taken from `axes(x)`,
# because `size` is not guaranteed to have a method for `Broadcasted`.
#
# - same shape as `x̄`: return `x̄` unchanged;
# - same number of elements: only trailing dimensions differ, so `trim` reshapes;
# - otherwise: sum `x̄` over every dimension in which `x` has extent 1, then `trim`.
function unbroadcast(x::Union{AbstractArray, Base.Broadcast.Broadcasted}, x̄)
    sz = map(length, axes(x))
    sz == size(x̄) && return x̄
    prod(sz) == length(x̄) && return trim(x, x̄)
    # Dimensions beyond `ndims(x)` count as extent 1, as `size(x, i)` does for arrays.
    # Dimensions that are kept are mapped to `ndims(x̄) + 1`, a dimension `x̄` does not
    # have, so summing over it is a no-op.
    extent(i) = i <= length(sz) ? sz[i] : 1
    dims = ntuple(i -> extent(i) == 1 ? i : ndims(x̄) + 1, Val(ndims(x̄)))
    return trim(x, accum_sum(x̄, dims = dims))
end
@@ -146,12 +170,46 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
146
170
147
171
const Numeric = Union{Number, AbstractArray{<: Number , N} where N}
148
172
149
- function ChainRulesCore. rrule (:: typeof (broadcasted), :: typeof (+ ), xs:: Numeric... )
150
- broadcast (+ , xs... ), ȳ -> (NoTangent (), NoTangent (), map (x -> unbroadcast (x, unthunk (ȳ)), xs)... )
173
+ # function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
174
+ # broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
175
+ # end
176
+
177
+ # Replace Zygote-like fully split broadcasting with one fused over easy operations
178
# Broadcasted `+` with any arguments is differentiated by `split_bc_plus`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...)
    return split_bc_plus(args...)
end

# Single `Array` argument: same behaviour, a separate method to resolve an ambiguity.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array)
    return split_bc_plus(arg)
end
180
# Rule for broadcasted `+` over any number of arguments.
#
# Returns the lazy `broadcasted(+, xs...)` (not materialised here) together with a
# pullback. The pullback unthunks the incoming cotangent once and passes it through
# `unbroadcast` for every argument; the two leading `NoTangent()`s are for `broadcasted`
# and for `+` themselves.
function split_bc_plus(xs...)
    function plus_back(Δ)
        Δun = unthunk(Δ)
        return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...)
    end
    return broadcasted(+, xs...), plus_back
end
186
# Element type of a lazy `+` broadcast, needed to hit the fast path.
#
# Uses the same computation Base applies when materialising a `Broadcasted`, so the
# answer is the actual result type of `+` on the argument element types (for example
# `Bool .+ Bool` gives `Int`, not `Bool`), and nested lazy arguments are handled too.
function Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple})
    return Broadcast.combine_eltypes(bc.f, bc.args)
end
152
188
153
- ChainRulesCore. rrule (:: typeof (broadcasted), :: typeof (- ), x:: Numeric , y:: Numeric ) = x .- y,
154
- Δ -> let Δ= unthunk (Δ); (NoTangent (), NoTangent (), unbroadcast (x, Δ), - unbroadcast (y, Δ)); end
189
# Materialising a lazy broadcast: the primal is `copy(bc)`, and the pullback passes the
# cotangent through unchanged, with `NoTangent()` for `copy` itself.
function (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted)
    copy_back(Δ) = (NoTangent(), Δ)
    return copy(bc), copy_back
end
155
190
156
- ChainRulesCore. rrule (:: typeof (broadcasted), :: typeof (* ), x:: Numeric , y:: Numeric ) = x.* y,
157
- z̄ -> let z̄= unthunk (z̄); (NoTangent (), NoTangent (), unbroadcast (x, z̄ .* conj .(y)), unbroadcast (y, z̄ .* conj .(x))); end
191
+ # ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
192
+ # Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
193
+
194
# Rule for two-argument broadcasted `-`.
#
# The primal stays a lazy `broadcasted(-, x, y)`. The pullback unthunks the cotangent
# once, then returns it reduced to `x` and, negated, reduced to `y`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
    function minus_back(Δ)
        Δun = unthunk(Δ)
        x̄ = unbroadcast(x, Δun)
        # Ideally the negation would be fused into `unbroadcast` (a `mapreduce` rather
        # than a `sum`) when `y` is a smaller array.
        ȳ = -unbroadcast(y, Δun)
        return (NoTangent(), NoTangent(), x̄, ȳ)
    end
    return broadcasted(-, x, y), minus_back
end
201
+
202
+ # ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
203
+ # z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
204
+
205
+ using LinearAlgebra: dot
206
+
207
# Rule for two-argument broadcasted `*`.
#
# The primal stays a lazy `broadcasted(*, x, y)`. In the pullback, a `Number` factor
# gets its gradient from `dot` of the other factor with the cotangent; any other factor
# gets the cotangent times the conjugate of the other factor, reduced by `unbroadcast`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
    function times_back(Δ)
        Δun = unthunk(Δ)
        # When a factor is an array but a smaller one, `mapreduce` might be usable in
        # place of `dot`.
        x̄ = if x isa Number
            dot(y, Δun)
        else
            unbroadcast(x, Δun .* conj.(y))
        end
        ȳ = if y isa Number
            dot(x, Δun)
        else
            unbroadcast(y, Δun .* conj.(x))
        end
        return (NoTangent(), NoTangent(), x̄, ȳ)
    end
    return broadcasted(*, x, y), times_back
end
0 commit comments