@@ -56,15 +56,22 @@ function split_bc_rule(f::F, args...) where {F}
56
56
# Fast path: just broadcast, and use x & y to find derivative.
57
57
ys = f .(args... )
58
58
_print (" path 2" )
59
- function back_2 (dys)
59
+ function back_2_one (dys) # For f.(x) we do not need StructArrays / unzip at all
60
+ delta = broadcast (unthunk (dys), ys, args... ) do dy, y, a
61
+ das = only (derivatives_given_output (y, f, a))
62
+ dy * conj (only (das))
63
+ end
64
+ (NoTangent (), NoTangent (), unbroadcast (only (args), delta))
65
+ end
66
+ function back_2_many (dys)
60
67
deltas = splitcast (unthunk (dys), ys, args... ) do dy, y, as...
61
68
das = only (derivatives_given_output (y, f, as... ))
62
69
map (da -> dy * conj (da), das)
63
70
end
64
- dargs = map (unbroadcast, args, deltas)
71
+ dargs = map (unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast?
65
72
(NoTangent (), NoTangent (), dargs... )
66
73
end
67
- return ys, back_2
74
+ return ys, length (args) == 1 ? back_2_one : back_2_many
68
75
else
69
76
# Slow path: collect all the pullbacks & apply them later.
70
77
# Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
@@ -88,6 +95,62 @@ splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args
88
95
# warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
89
96
# Build the instantiated lazy broadcast of `f` over `args`, wrap it in a `StructArray`,
# and return that array's `components`.
function splitcast(f, args...)
    lazy = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(lazy))
end
90
97
98
+ #=
99
+ # This is how you could handle CuArrays, route them to unzip(map(...)) fallback path.
100
+ # Maybe 2nd derivatives too, to avoid writing a gradient for splitcast, rule for unzip is easy.
101
+
102
+ function Diffractor.splitmap(f, args...)
103
+ if any(a -> a isa CuArray, args)
104
+ Diffractor._print("unzip splitmap")
105
+ unzip(map(f, args...))
106
+ else
107
+ StructArrays.components(StructArray(Iterators.map(f, args...)))
108
+ end
109
+ end
110
+ function Diffractor.splitcast(f, args...)
111
+ if any(a -> a isa CuArray, args)
112
+ Diffractor._print("unzip splitcast")
113
+ unzip(broadcast(f, args...))
114
+ else
115
+ StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
116
+ end
117
+ end
118
+
119
+ gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1]
120
+ gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1]
121
+
122
+ =#
123
+
124
"""
    Get{i}()
    Get(i)

Callable singleton which extracts the `i`-th element of its argument: `Get(i)(x) == x[i]`.
The index is stored as a type parameter, `Get(i)` converts it with `Int(i)`.

This is defined *before* the `@generated` methods of `unzip`, because their generators
call `Get(i)` while building the expression, and a generator may only call functions
which were defined before the generated function itself.
"""
struct Get{i} end
Get(i) = Get{Int(i)}()
(::Get{i})(x) where {i} = x[i]

"""
    unzip(xs::AbstractArray)

Convert an array of tuples into a tuple of arrays, so that `unzip(xs)[i][k] == xs[k][i]`.

Which method runs depends on `eltype(xs)`:
* `Tuple{T}` with `T` a bits type: the single component is `reinterpret(T, xs)`.
* A bits-type tuple with fewer than two non-singleton fields: the non-singleton field is
  `reinterpret`ed, and each singleton field becomes `similar(xs, T)`.
* Any other tuple type: one `map` over `xs` per field.
* Otherwise (e.g. `eltype(xs) == Any`): the number of components is read from
  `first(xs)`, which must be a `Tuple`, else an `ArgumentError` is thrown.
  This path calls `first`, so it cannot handle an empty array.
"""
function unzip(xs::AbstractArray)
    x1 = first(xs)
    x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples"))
    return unzip(xs, Val(length(x1)))  # like Zygote's unzip
end

# Fallback with the tuple length `N` supplied as a `Val`: one `map` per component.
@generated function unzip(xs, ::Val{N}) where {N}
    each = [:(map($(Get(i)), xs)) for i in 1:N]
    return Expr(:tuple, each...)
end

# One-element tuples. `reinterpret` only accepts bits types, so anything else uses `map`.
function unzip(xs::AbstractArray{Tuple{T}}) where {T}
    return isbitstype(T) ? (reinterpret(T, xs),) : (map(Get(1), xs),)  # best case, no copy
end

@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple}
    # The `reinterpret` route needs a bits-type tuple with at most one non-singleton field.
    each = if isbitstype(Ts) && count(!Base.issingletontype, Ts.parameters) < 2
        # good case, no copy of data, some trivial arrays
        [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters]
    else
        [:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))]
    end
    return Expr(:tuple, each...)
end
148
+
149
# Reverse rule for `unzip`. The primal result is `unzip(xs)`. The pullback unthunks the
# incoming cotangent, splats its components, and broadcasts `tuple` over them to rebuild
# an array of tuples; the function `unzip` itself gets `NoTangent()`.
function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray)
    function unzip_pullback(dy)
        parts = unthunk(dy)
        return (NoTangent(), tuple.(parts...))
    end
    return unzip(xs), unzip_pullback
end
153
+
91
154
# For certain cheap operations we can easily allow fused broadcast:
92
155
93
156
(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), args... ) = split_bc_plus (args... )
0 commit comments