@@ -29,8 +29,12 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
29
29
return r
30
30
end
31
31
32
# Debug tracing hook. The active definition discards its argument and returns `nothing`;
# swap in the commented-out variant below to print each message in magenta instead.
function _print(s)
    return nothing
end
# _print(s) = printstyled(s, "\n"; color=:magenta)
34
+
32
35
# Broadcast over one element is just map
33
36
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
    # With a single `Array` argument the broadcast is re-expressed as `map`,
    # by calling this same `∂⃖ₙ` object on `(map, f, a)`.
    _print(" path 0")
    mapped = ∂⃖ₙ(map, f, a)
    return mapped
end
36
40
@@ -40,16 +44,16 @@ using ChainRulesCore: derivatives_given_output
40
44
# First-order rule for one `Array` argument: forwards to `split_bc_rule`.
# This method exists to resolve a dispatch ambiguity.
function (::∂⃖{1})(::typeof(broadcasted), f, arg::Array)
    return split_bc_rule(f, arg)
end
41
45
function split_bc_rule (f:: F , args... ) where {F}
42
46
T = Broadcast. combine_eltypes (f, args)
43
- if T == Bool && Base . issingletontype (F)
47
+ if T == Bool
44
48
# Trivial case
49
+ _print (" path 1" )
45
50
back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
46
51
return f .(args... ), back_1
47
- # elseif all(a -> a isa Numeric, args) && isconcretetype(Core.Compiler._return_type(
48
52
elseif isconcretetype (Core. Compiler. _return_type (
49
53
derivatives_given_output, Tuple{T, F, map (eltype, args)... }))
50
54
# Fast path: just broadcast, and use x & y to find derivative.
51
55
ys = f .(args... )
52
- # println(" 2")
56
+ _print ( " path 2" )
53
57
function back_2 (dys)
54
58
deltas = splitcast (unthunk (dys), ys, args... ) do dy, y, as...
55
59
das = only (derivatives_given_output (y, f, as... ))
@@ -61,7 +65,7 @@ function split_bc_rule(f::F, args...) where {F}
61
65
return ys, back_2
62
66
else
63
67
# Slow path: collect all the pullbacks & apply them later.
64
- # println(" 3")
68
+ _print ( " path 3" )
65
69
ys, backs = splitcast (rrule_via_ad, DiffractorRuleConfig (), f, args... )
66
70
function back_3 (dys)
67
71
deltas = splitmap (backs, unthunk (dys)) do back, dy
@@ -78,108 +82,13 @@ using StructArrays
78
82
# Lazily map `f` over `args`, collect the results into a `StructArray`, and return
# that array's components.
# warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
function splitmap(f, args...)
    lazy = Iterators.map(f, args...)
    return StructArrays.components(StructArray(lazy))
end
79
83
# Broadcast counterpart of `splitmap`: build the instantiated lazy broadcast of `f`
# over `args`, collect it into a `StructArray`, and return that array's components.
function splitcast(f, args...)
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(bc))
end
80
84
81
- unbroadcast (f:: Function , x̄) = accum_sum (x̄)
82
- unbroadcast (:: Val , _) = NoTangent ()
83
- unbroadcast (x:: AbstractArray , x̄:: NoTangent ) = NoTangent ()
84
- accum_sum (xs:: AbstractArray{<:NoTangent} ; dims = :) = NoTangent ()
85
-
86
- #=
87
-
88
- julia> xs = randn(10_000);
89
- julia> @btime Zygote.gradient(x -> sum(abs2, x), $xs)
90
- 4.744 μs (2 allocations: 78.17 KiB)
91
- julia> @btime Diffractor.unthunk.(gradient(x -> sum(abs2, x), $xs));
92
- 3.307 μs (2 allocations: 78.17 KiB)
93
-
94
- # Simple function
95
-
96
- julia> @btime Zygote.gradient(x -> sum(abs2, exp.(x)), $xs);
97
- 72.541 μs (29 allocations: 391.47 KiB) # with dual numbers -- like 4 copies
98
-
99
- julia> @btime gradient(x -> sum(abs2, exp.(x)), $xs);
100
- 45.875 μs (36 allocations: 235.47 KiB) # fast path -- one copy forward, one back
101
- 44.042 μs (32 allocations: 313.48 KiB) # slow path -- 3 copies, extra is closure?
102
- 61.167 μs (12 allocations: 703.41 KiB) # with `map` rule as before -- worse
103
-
104
-
105
- # Composed function, Zygote struggles
106
-
107
- julia> @btime Zygote.gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
108
- 97.167 μs (29 allocations: 391.61 KiB) # with dual numbers (Zygote master)
109
- 93.238 ms (849567 allocations: 19.22 MiB) # without, thus Zygote.pullback
110
-
111
- julia> @btime gradient(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
112
- 55.290 ms (830060 allocations: 49.75 MiB) # slow path
113
- 14.747 ms (240043 allocations: 7.25 MiB) # with `map` rule as before -- better!
114
-
115
- # Compare unfused
116
-
117
- julia> @btime gradient(x -> sum(abs2, identity.(cbrt.(x))), $xs);
118
- 69.458 μs (50 allocations: 392.09 KiB) # fast path -- two copies forward, two back
119
- 75.041 μs (46 allocations: 470.11 KiB) # slow path -- 5 copies
120
- 135.541 μs (27 allocations: 1.30 MiB) # with `map` rule as before -- worse
121
-
122
-
123
- # Lazy +,-,* for partial fusing
124
-
125
- julia> @btime Zygote.gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
126
- 81.250 μs (21 allocations: 625.47 KiB) # special rules + dual numbers, 4 more copies than best
127
-
128
- julia> @btime gradient(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
129
- 57.166 μs (49 allocations: 470.22 KiB) # broadcast in *, -
130
- 54.583 μs (46 allocations: 314.06 KiB) # broadcasted -- two less copies
131
- 72.958 μs (26 allocations: 1016.38 KiB) # with `map` rule as before
132
-
133
- julia> gradient((x,y) -> sum(abs2, exp.(2 .* x .+ y)), xs, (rand(10)'))
134
- ERROR: MethodError: no method matching size(::Base.Broadcast.Broadcasted # hmm
135
-
136
- julia> @btime gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
137
- 7.127 ms (75 allocations: 22.97 MiB) # after
138
- 12.956 ms (57 allocations: 76.37 MiB) # before
139
-
140
- ulia> @btime Zygote.gradient((x,y) -> sum(abs2, exp.(x .+ y)), $xs, $(rand(100)'));
141
- 9.937 ms (48 allocations: 45.86 MiB)
142
-
143
- =#
85
+ # For certain cheap operations we can easily allow fused broadcast:
144
86
145
- # The below is from Zygote: TODO : DO we want to do something better here?
146
-
147
- accum_sum (xs:: Nothing ; dims = :) = NoTangent ()
148
- accum_sum (xs:: AbstractArray{Nothing} ; dims = :) = NoTangent ()
149
- accum_sum (xs:: AbstractArray{<:Number} ; dims = :) = sum (xs, dims = dims)
150
- accum_sum (xs:: AbstractArray{<:AbstractArray{<:Number}} ; dims = :) = sum (xs, dims = dims)
151
- accum_sum (xs:: Number ; dims = :) = xs
152
-
153
- # https://github.com/FluxML/Zygote.jl/issues/594
154
- function Base. reducedim_init (:: typeof (identity), :: typeof (accum), A:: AbstractArray , region)
155
- Base. reducedim_initarray (A, region, NoTangent (), Union{Nothing,eltype (A)})
156
- end
157
-
158
- trim (x, Δ) = reshape (Δ, ntuple (i -> size (Δ, i), Val (ndims (x))))
159
-
160
- unbroadcast (x:: Union{AbstractArray, Base.Broadcast.Broadcasted} , x̄) =
161
- size (x) == size (x̄) ? x̄ :
162
- length (x) == length (x̄) ? trim (x, x̄) :
163
- trim (x, accum_sum (x̄, dims = ntuple (i -> size (x, i) == 1 ? i : ndims (x̄)+ 1 , Val (ndims (x̄)))))
164
-
165
- unbroadcast (x:: Number , x̄) = accum_sum (x̄)
166
- unbroadcast (x:: Tuple{<:Any} , x̄) = (accum_sum (x̄),)
167
- unbroadcast (x:: Base.RefValue , x̄) = (x= accum_sum (x̄),)
168
-
169
- unbroadcast (x:: AbstractArray , x̄:: Nothing ) = NoTangent ()
170
-
171
- const Numeric = Union{Number, AbstractArray{<: Number , N} where N}
172
-
173
- # function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
174
- # broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
175
- # end
176
-
177
- # Replace Zygote-like fully split broadcasting with one fused over easy operations
178
87
# Broadcast `+` at first order forwards all arguments to `split_bc_plus`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...)
    return split_bc_plus(args...)
end
# Single-`Array` method, present only to resolve a dispatch ambiguity.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array)
    return split_bc_plus(arg)
end
180
89
# Rule for broadcast `+` that keeps the forward result lazy.
#
# Returns a 2-tuple:
#   1. the un-materialised `broadcasted(+, xs...)`;
#   2. a pullback taking the output cotangent `Δ` and returning
#      `(NoTangent(), NoTangent(), d_xs...)`, where each `d_x` is
#      `unbroadcast(x, unthunk(Δ))`.
#
# The signature previously carried a `where {F}` clause whose type variable was never
# used; Julia warns about such definitions, so it has been removed.
function split_bc_plus(xs...)
    function plus_back(Δ)
        Δun = unthunk(Δ)  # unthunk once and reuse for every argument
        _print(" broadcast +")
        return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δun), xs)...)
    end
    return broadcasted(+, xs...), plus_back
end
@@ -188,28 +97,61 @@ Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
188
97
189
98
# `copy` of a `Broadcasted`: the forward pass materialises it with `copy(bc)`, and the
# pullback passes the cotangent `Δ` through unchanged as the tangent for `bc`.
function (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted)
    copy_back(Δ) = (NoTangent(), Δ)
    return copy(bc), copy_back
end
190
99
191
- # ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
192
- # Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
193
-
194
100
# Rule for broadcast `-`: the forward result stays a lazy `broadcasted(-, x, y)`.
# The pullback gives `x` the un-broadcast cotangent and `y` its negation.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
    function minus_back(Δ)
        Δun = unthunk(Δ)
        _print(" broadcast -")
        dx = unbroadcast(x, Δun)
        # Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
        dy = -unbroadcast(y, Δun)
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(-, x, y), minus_back
end
201
107
202
- # ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
203
- # z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
204
-
205
108
using LinearAlgebra: dot
206
109
207
110
# Rule for broadcast `*`: the forward result stays a lazy `broadcasted(*, x, y)`.
# For each argument the pullback picks, in this order:
#   * `NoTangent()` when its `eltype` is `Bool`;
#   * a `dot` with the other argument when it is a `Number`;
#   * otherwise `unbroadcast` of the cotangent times the conjugate of the other argument.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
    function times_back(Δ)
        Δun = unthunk(Δ)
        _print(" broadcast *")
        dx = if eltype(x) == Bool
            NoTangent()
        elseif x isa Number
            dot(y, Δun)
        else
            unbroadcast(x, Δun .* conj.(y))
        end
        dy = if eltype(y) == Bool
            NoTangent()
        elseif y isa Number
            dot(x, Δun)
        else
            unbroadcast(y, Δun .* conj.(x))
        end
        # When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
        # Will things like this work? Ref([1,2]) .* [1,2,3]
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return broadcasted(*, x, y), times_back
end
120
+
121
# Rule for broadcast `conj`: the forward result stays a lazy `broadcasted(conj, x)`.
#
# The pullback returns one cotangent per argument of the call `(broadcasted, conj, x)`,
# i.e. three entries, matching the other `broadcasted` rules in this file. It previously
# returned only two, leaving no slot for `x`.
# NOTE(review): for `x::Array` this method and the generic
# `(::∂⃖{1})(::typeof(broadcasted), f, arg::Array)` rule look ambiguous; an `Array`
# method like the one provided for `+` may be needed -- confirm.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x)
    conj_back(Δ) = (NoTangent(), NoTangent(), conj(unthunk(Δ)))
    return broadcasted(conj, x), conj_back
end
123
# Rule for broadcast `conj` over real arrays: `x` itself is returned as the forward
# result and the cotangent passes through unchanged.
#
# Two fixes relative to the previous definition:
#   * `AbstractArray{Real}` only matched arrays whose eltype is exactly the abstract type
#     `Real` (type parameters are invariant), so e.g. `Vector{Float64}` never reached
#     this method; the signature now uses `AbstractArray{<:Real}`.
#   * The pullback returns one cotangent per argument of `(broadcasted, conj, x)`, i.e.
#     three entries, like the other `broadcasted` rules in this file.
# NOTE(review): for `x::Array{<:Real}` this method and the generic
# `(::∂⃖{1})(::typeof(broadcasted), f, arg::Array)` rule look ambiguous; an `Array`
# method like the one provided for `+` may be needed -- confirm.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real})
    conj_back_real(Δ) = (NoTangent(), NoTangent(), Δ)
    return x, conj_back_real
end
125
+
126
# Reduce a cotangent `dx`, shaped like the broadcast output, to something matching the
# array (or lazy `Broadcasted`) argument `x`.
#   * Equal lengths: only projection is needed.
#   * Otherwise: sum `dx` over every dimension in which `x` has size 1 (or which `x`
#     lacks), then project.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
    N = ndims(dx)
    if length(x) == length(dx)
        # handles trivial reshapes, offsets, structured matrices, row vectors
        return ProjectTo(x)(dx)
    end
    # This is an awful hack to get type-stable `dims`:
    # dimensions that should not be summed are mapped to the out-of-range index `N + 1`.
    dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N + 1, N)
    return ProjectTo(x)(sum(dx; dims))
end
# A `NoTangent` cotangent is returned as a fresh `NoTangent()`.
unbroadcast(::Base.AbstractArrayOrBroadcasted, ::NoTangent) = NoTangent()
137
+
138
# Scalar-like arguments: the whole cotangent is summed, then projected onto the argument.

# A `Number` receives the total of `dx`.
function unbroadcast(x::Number, dx)
    return ProjectTo(x)(sum(dx))
end

# A `Function` receives the total of `df`.
# Fix: this method previously called `ProjectTo(x)`, but no `x` is bound here (the
# argument is named `f`), so any call raised `UndefVarError`.
# NOTE(review): relies on `ProjectTo` accepting a function -- confirm against ChainRulesCore.
function unbroadcast(f::Function, df)
    return ProjectTo(f)(sum(df))
end

# A `Ref` receives the total of `dx`, re-wrapped in a `Ref` before projection.
function unbroadcast(x::Base.RefValue, dx)
    return ProjectTo(x)(Ref(sum(dx)))
end
141
+
142
# Arguments treated as non-differentiable: the cotangent is ignored and `NoTangent()`
# is returned.
function unbroadcast(::Bool, dx)
    return NoTangent()
end
function unbroadcast(::AbstractArray{Bool}, dx)
    return NoTangent()
end
function unbroadcast(::AbstractArray{Bool}, ::NoTangent)  # ambiguity
    return NoTangent()
end
function unbroadcast(::Val, dx)
    return NoTangent()
end
146
+ # Maybe more non-diff types? Some fallback?
147
+
148
# One-element tuple: sum the entire cotangent, wrap it in a `Tangent{T}`, and project
# the result onto `x`.
function unbroadcast(x::T, dx) where {T<:Tuple{Any}}
    total = sum(dx)
    return ProjectTo(x)(Tangent{T}(total))
end
149
# General tuple argument. When `dx` has as many elements as `x` it is used directly;
# otherwise it is summed over dimensions `2:ndims(dx)`. The result is converted to an
# `NTuple` of `length(x)` entries and projected onto `x`.
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
    _print(" unbroadcast tuple")
    val = length(x) == length(dx) ? dx : sum(dx; dims=2:ndims(dx))
    return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
end
0 commit comments