@@ -29,28 +29,30 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
29
29
return r
30
30
end
31
31
32
- _print (s) = nothing
33
- # _print(s) = printstyled(s, "\n"; color=:magenta)
32
+ # Reverse mode broadcast rules
33
+
34
+ using ChainRulesCore: derivatives_given_output
35
+
36
+ # _print(s) = nothing
37
+ _print (s) = printstyled (s, " \n " ; color= :magenta )
34
38
35
39
# Broadcast over one element is just map
36
40
function (∂⃖ₙ: :∂⃖ {N})(:: typeof (broadcasted), f, a:: Array ) where {N}
37
41
_print (" path 0" )
38
42
∂⃖ₙ (map, f, a)
39
43
end
40
44
41
- using ChainRulesCore: derivatives_given_output
42
-
43
45
(:: ∂⃖{1 })(:: typeof (broadcasted), f, args... ) = split_bc_rule (f, args... )
44
46
(:: ∂⃖{1 })(:: typeof (broadcasted), f, arg:: Array ) = split_bc_rule (f, arg) # ambiguity
45
47
function split_bc_rule (f:: F , args... ) where {F}
46
48
T = Broadcast. combine_eltypes (f, args)
47
- if T == Bool
48
- # Trivial case
49
+ TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
50
+ if eltype (T) == Bool
51
+ # Trivial case: non-differentiable output
49
52
_print (" path 1" )
50
53
back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
51
54
return f .(args... ), back_1
52
- elseif isconcretetype (Core. Compiler. _return_type (
53
- derivatives_given_output, Tuple{T, F, map (eltype, args)... }))
55
+ elseif T <: Number && isconcretetype (TΔ)
54
56
# Fast path: just broadcast, and use x & y to find derivative.
55
57
ys = f .(args... )
56
58
_print (" path 2" )
@@ -65,8 +67,9 @@ function split_bc_rule(f::F, args...) where {F}
65
67
return ys, back_2
66
68
else
67
69
# Slow path: collect all the pullbacks & apply them later.
70
+ # Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
68
71
_print (" path 3" )
69
- ys, backs = splitcast (rrule_via_ad, DiffractorRuleConfig (), f, args... )
72
+ ys, backs = splitcast (∂⃖ {1} (), f, args... )
70
73
function back_3 (dys)
71
74
deltas = splitmap (backs, unthunk (dys)) do back, dy
72
75
map (unthunk, back (dy))
@@ -78,8 +81,11 @@ function split_bc_rule(f::F, args...) where {F}
78
81
end
79
82
end
80
83
84
+ # This uses "mulltimap"-like constructs:
85
+
81
86
using StructArrays
82
- splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... ))) # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
87
+ splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... )))
88
+ # warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
83
89
splitcast (f, args... ) = StructArrays. components (StructArray (Broadcast. instantiate (Broadcast. broadcasted (f, args... ))))
84
90
85
91
# For certain cheap operations we can easily allow fused broadcast:
107
113
108
114
using LinearAlgebra: dot
109
115
110
- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x, y)
116
+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x, y) # should this be vararg, or will laziness handle it?
111
117
broadcasted (* , x, y), Δ -> let Δun = unthunk (Δ)
112
118
_print (" broadcast *" )
113
119
dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δun) : unbroadcast (x, Δun .* conj .(y))
@@ -117,41 +123,88 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y)
117
123
(NoTangent (), NoTangent (), dx, dy)
118
124
end
119
125
end
126
+ # Alternative to `x isa Number` etc above... but not quite right!
127
+ # (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x, y::Number) = rrule_via_ad(DiffractorRuleConfig(), *, x, y)
128
+
129
+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x, :: Val{2} )
130
+ _print (" broadcast ^2" )
131
+ broadcasted (* , x, x), Δ -> begin
132
+ dx = unbroadcast (x, 2 .* Δ .* conj .(x))
133
+ (NoTangent (), NoTangent (), NoTangent (), dx, NoTangent ())
134
+ end
135
+ end
136
+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Number , :: Val{2} )
137
+ _print (" simple ^2" )
138
+ x^ 2 , Δ -> (NoTangent (), NoTangent (), NoTangent (), 2 * Δ * conj (x), NoTangent ())
139
+ end
140
+
141
+ # function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x, y) # not obvious whether this is better than automatic
142
+ # broadcasted(/, x, y), Δ -> let Δun = unthunk(Δ)
143
+ # _print("broadcast /")
144
+ # dx = unbroadcast(x, Δ ./ conj.(y))
145
+ # dy = unbroadcast(y, .-Δ .* conj.(res ./ y))
146
+ # (NoTangent(), NoTangent(), dx, dy)
147
+ # end
148
+ # end
149
+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (/ ), x, y:: Number )
150
+ _print (" simple /" )
151
+ z, back = ∂⃖ {1} ()(/ , x, y)
152
+ z, Δ -> begin
153
+ _, dx, dy = back (Δ)
154
+ (NoTangent (), NoTangent (), dx, dy) # maybe there should be a funciton for this? Use for conj, identity too
155
+ end
156
+ end
120
157
121
158
(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x) =
122
159
broadcasted (conj, x), Δ -> (NoTangent (), conj (unthunk (Δ)))
123
160
(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x:: AbstractArray{Real} ) =
124
161
x, Δ -> (NoTangent (), Δ)
125
162
163
+ (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (identity), x) =
164
+ x, Δ -> (NoTangent (), Δ)
165
+
166
+ # All broadcasts use `unbroadcast` to reduce to correct shape:
167
+
126
168
function unbroadcast (x:: Base.AbstractArrayOrBroadcasted , dx)
127
169
N = ndims (dx)
128
170
if length (x) == length (dx)
129
171
ProjectTo (x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
130
172
else
131
- # This is an awful hack to get type-stable `dims`
132
- dims = ntuple (d -> get (size (x), d, 1 ) == 1 ? d : N+ 1 , N)
173
+ dims = ntuple (d -> get (size (x), d, 1 ) == 1 ? d : N+ 1 , N) # awful hack to get type-stable `dims`
133
174
ProjectTo (x)(sum (dx; dims))
134
175
end
135
176
end
136
177
unbroadcast (x:: Base.AbstractArrayOrBroadcasted , dx:: NoTangent ) = NoTangent ()
137
178
179
+ unbroadcast (x:: T , dx) where {T<: Tuple{Any} } = ProjectTo (x)(Tangent {T} (sum (dx)))
180
+ function unbroadcast (x:: T , dx) where {T<: Tuple{Vararg{Any,N}} } where {N}
181
+ _print (" unbroadcast tuple" )
182
+ val = if length (x) == length (dx)
183
+ dx
184
+ else
185
+ sum (dx; dims= 2 : ndims (dx))
186
+ end
187
+ ProjectTo (x)(NTuple {length(x)} (val)) # Tangent
188
+ end
189
+
190
+ unbroadcast (f:: Function , df) = sum (df)
138
191
unbroadcast (x:: Number , dx) = ProjectTo (x)(sum (dx))
139
- unbroadcast (f:: Function , df) = ProjectTo (x)(sum (df))
140
192
unbroadcast (x:: Base.RefValue , dx) = ProjectTo (x)(Ref (sum (dx)))
141
193
142
194
unbroadcast (:: Bool , dx) = NoTangent ()
143
195
unbroadcast (:: AbstractArray{Bool} , dx) = NoTangent ()
144
196
unbroadcast (:: AbstractArray{Bool} , :: NoTangent ) = NoTangent () # ambiguity
145
197
unbroadcast (:: Val , dx) = NoTangent ()
146
- # Maybe more non-diff types? Some fallback?
147
198
148
- unbroadcast (x:: T , dx) where {T<: Tuple{Any} } = ProjectTo (x)(Tangent {T} (sum (dx)))
149
- function unbroadcast (x:: T , dx) where {T<: Tuple{Vararg{Any,N}} } where {N}
150
- _print (" unbroadcast tuple" )
151
- val = if length (x) == length (dx)
152
- dx
199
+ function unbroadcast (x, dx)
200
+ p = ProjectTo (x)
201
+ if dx isa AbstractZero || p isa ProjectTo{<: AbstractZero }
202
+ return NoTangent ()
203
+ end
204
+ b = Broadcast. broadcastable (x)
205
+ if b isa Ref # then x is scalar under broadcast
206
+ return p (sum (dx))
153
207
else
154
- sum (dx; dims = 2 : ndims (dx) )
208
+ error ( " don't know how to handle broadcast gradient for x:: $( typeof (x)) " )
155
209
end
156
- ProjectTo (x)(NTuple {length(x)} (val)) # Tangent
157
210
end
0 commit comments