44
44
45
45
# `∂⃖{1}` applied to `broadcasted(f, args...)`: all of the work is delegated
# to `split_bc_rule`.
function (::∂⃖{1})(::typeof(broadcasted), f, args...)
    return split_bc_rule(f, args...)
end

# Same rule for a single `Array` argument. This method exists only to
# resolve a dispatch ambiguity; it forwards to `split_bc_rule` as well.
function (::∂⃖{1})(::typeof(broadcasted), f, arg::Array)
    return split_bc_rule(f, arg)
end
47
- function split_bc_rule (f:: F , args... ) where {F}
47
+ function split_bc_rule (f:: F , args:: Vararg{Any,N} ) where {F,N }
48
48
T = Broadcast. combine_eltypes (f, args)
49
49
TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
50
50
if eltype (T) == Bool
@@ -71,10 +71,11 @@ function split_bc_rule(f::F, args...) where {F}
71
71
dargs = map (unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast?
72
72
(NoTangent (), NoTangent (), dargs... )
73
73
end
74
- return ys, length (args) == 1 ? back_2_one : back_2_many
74
+ return ys, N == 1 ? back_2_one : back_2_many
75
75
else
76
76
# Slow path: collect all the pullbacks & apply them later.
77
- # Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
77
+ # (Since broadcast makes no guarantee about order of calls, and un-fusing
78
+ # can change the number of calls, this does not bother to try to reverse.)
78
79
_print (" path 3" )
79
80
ys, backs = splitcast (∂⃖ {1} (), f, args... )
80
81
function back_3 (dys)
@@ -84,15 +85,21 @@ function split_bc_rule(f::F, args...) where {F}
84
85
dargs = map (unbroadcast, args, Base. tail (deltas)) # no real need to close over args here
85
86
(NoTangent (), sum (first (deltas)), dargs... )
86
87
end
88
+ back_3 (:: AbstractZero ) = (NoTangent (), map (Returns (ZeroTangent ()), args)... )
87
89
return ys, back_3
88
90
end
89
91
end
90
92
91
- # This uses "mulltimap"-like constructs:
93
+ # Skip AD'ing through the axis computation
94
function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
    # Primal: call `instantiate` directly instead of differentiating through it.
    instantiated = Base.Broadcast.instantiate(bc)
    # Pullback: no tangent for the `instantiate` function itself, and the
    # incoming Δ is passed through unchanged as the tangent for `bc`.
    function instantiate_pullback(Δ)
        return Core.tuple(NoTangent(), Δ)
    end
    return instantiated, instantiate_pullback
end
98
+
99
+ # This uses "multimap"-like constructs:
92
100
93
101
using StructArrays
94
102
# Apply `f` across `args` with the lazy `Iterators.map`, gather the results
# into a `StructArray`, and return that array's components.
function splitmap(f, args...)
    mapped = Iterators.map(f, args...)
    return StructArrays.components(StructArray(mapped))
end
95
- # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
96
103
# Broadcasting counterpart of `splitmap`: build the instantiated lazy
# broadcast of `f` over `args`, gather it into a `StructArray`, and return
# that array's components.
function splitcast(f, args...)
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(bc))
end
97
104
98
105
#=
156
163
# `∂⃖{1}` applied to `broadcasted(+, args...)`: delegated to `split_bc_plus`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...)
    return split_bc_plus(args...)
end

# Single-`Array` variant, present only to resolve a dispatch ambiguity.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array)
    return split_bc_plus(arg)
end
158
165
# Rule for `broadcasted(+, xs...)`.
#
# Returns a 2-tuple:
#   1. the lazy `broadcasted(+, xs...)` object, and
#   2. a pullback taking the output cotangent `Δraw`.
#
# The pullback unthunks `Δraw` and returns
# `(NoTangent(), NoTangent(), d1, d2, ...)`, where each `di` is
# `unbroadcast(xi, Δ)` for the matching argument `xi`.
#
# The signature used to carry a `where {F}` clause whose parameter `F`
# appeared nowhere in the argument list; Julia warns about such unused
# static parameters at definition time, so the clause has been removed.
function split_bc_plus(xs...)
    ys = broadcasted(+, xs...)
    function back_plus(Δraw)
        Δ = unthunk(Δraw)
        _print(" broadcast +")
        dxs = map(x -> unbroadcast(x, Δ), xs)
        return (NoTangent(), NoTangent(), dxs...)
    end
    return ys, back_plus
end
164
171
Base. eltype (bc:: Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple} ) =
@@ -167,20 +174,20 @@ Base.eltype(bc::Broadcast.Broadcasted{<:Any, <:Any, typeof(+), <:Tuple}) =
167
174
# `∂⃖{1}` applied to `copy(bc)` for a `Broadcasted` object: the primal is
# `copy(bc)`, and the pullback hands the incoming Δ back unchanged as the
# tangent for `bc` (with `NoTangent()` for `copy` itself).
function (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted)
    copied = copy(bc)
    copy_pullback(Δ) = (NoTangent(), Δ)
    return copied, copy_pullback
end
168
175
169
176
# `∂⃖{1}` applied to `broadcasted(-, x, y)`.
#
# Returns the lazy `broadcasted(-, x, y)` together with a pullback. The
# pullback unthunks the incoming cotangent and returns
# `(NoTangent(), NoTangent(), dx, dy)` with `dx = unbroadcast(x, Δ)` and
# `dy = -unbroadcast(y, Δ)`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
    diff_bc = broadcasted(-, x, y)
    function back_minus(Δraw)
        Δ = unthunk(Δraw)
        _print(" broadcast -")
        dx = unbroadcast(x, Δ)
        # Ideally you could fuse the - into unbroadcast, mapreduce() not sum,
        # when y is a smaller array
        dy = -unbroadcast(y, Δ)
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return diff_bc, back_minus
end
176
183
177
184
using LinearAlgebra: dot
178
185
179
186
function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x, y) # should this be vararg, or will laziness handle it?
180
- broadcasted (* , x, y), Δ -> let Δun = unthunk (Δ )
187
+ broadcasted (* , x, y), Δraw -> let Δ = unthunk (Δraw )
181
188
_print (" broadcast *" )
182
- dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δun ) : unbroadcast (x, Δun .* conj .(y))
183
- dy = eltype (y)== Bool ? NoTangent () : y isa Number ? dot (x, Δun ) : unbroadcast (y, Δun .* conj .(x))
189
+ dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δ ) : unbroadcast (x, Δ .* conj .(y))
190
+ dy = eltype (y)== Bool ? NoTangent () : y isa Number ? dot (x, Δ ) : unbroadcast (y, Δ .* conj .(x))
184
191
# When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
185
192
# Will things like this work? Ref([1,2]) .* [1,2,3]
186
193
(NoTangent (), NoTangent (), dx, dy)
0 commit comments