Skip to content

Commit 63797fd

Browse files
committed
avoid StructArrays sometimes, add unzip
1 parent 66f2b6a commit 63797fd

File tree

1 file changed

+66
-3
lines changed

1 file changed

+66
-3
lines changed

src/stage1/broadcast.jl

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,22 @@ function split_bc_rule(f::F, args...) where {F}
5656
# Fast path: just broadcast, and use x & y to find derivative.
5757
ys = f.(args...)
5858
_print("path 2")
59-
function back_2(dys)
59+
# Pullback for the single-argument case `f.(x)`: with one input there is only one
# derivative per element, so no StructArrays / unzip step is needed.
function back_2_one(dys)
    x = only(args)
    delta = broadcast(unthunk(dys), ys, x) do dy, y, a
        # `derivatives_given_output` is unwrapped twice with `only`:
        # once for the single output, once for the single input.
        da = only(only(derivatives_given_output(y, f, a)))
        return dy * conj(da)
    end
    return (NoTangent(), NoTangent(), unbroadcast(x, delta))
end
66+
# Pullback for two or more broadcast arguments: each element yields a tuple of
# per-argument cotangents, which `splitcast` returns as separate components.
function back_2_many(dys)
    deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as...
        partials = only(derivatives_given_output(y, f, as...))
        return map(p -> dy * conj(p), partials)
    end
    # ideally sum in unbroadcast could be part of splitcast?
    grads = map(unbroadcast, args, deltas)
    return (NoTangent(), NoTangent(), grads...)
end
67-
return ys, back_2
74+
return ys, length(args)==1 ? back_2_one : back_2_many
6875
else
6976
# Slow path: collect all the pullbacks & apply them later.
7077
# Since broadcast makes no guarantee about order, this does not bother to try to reverse it.
@@ -88,6 +95,62 @@ splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args
8895
# warning: splitmap(identity, [1,2,3,4]) === NamedTuple()
8996
# Build the lazy broadcast of `f` over `args`, collect it into a StructArray, and
# return that StructArray's components.
function splitcast(f, args...)
    lazy = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(lazy))
end
9097

98+
#=
99+
# This is how you could handle CuArrays: route them to the unzip(map(...)) fallback path.
100+
# This may also help with 2nd derivatives: it avoids writing a gradient for splitcast, and the rule for unzip is easy.
101+
102+
function Diffractor.splitmap(f, args...)
103+
if any(a -> a isa CuArray, args)
104+
Diffractor._print("unzip splitmap")
105+
unzip(map(f, args...))
106+
else
107+
StructArrays.components(StructArray(Iterators.map(f, args...)))
108+
end
109+
end
110+
function Diffractor.splitcast(f, args...)
111+
if any(a -> a isa CuArray, args)
112+
Diffractor._print("unzip splitcast")
113+
unzip(broadcast(f, args...))
114+
else
115+
StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
116+
end
117+
end
118+
119+
gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1]
120+
gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1]
121+
122+
=#
123+
124+
"""
    Get(i)(x) = x[i]

Callable singleton which extracts component `i` of `x`. The index is stored in the
type, so an instance can be interpolated into the expressions built for `unzip`.
"""
struct Get{i} end
Get(i) = Get{Int(i)}()
(::Get{i})(x) where {i} = x[i]

# NOTE: `Get` and the helpers below are deliberately defined *before* the `@generated`
# methods, because a generator may only call functions which already exist at the
# point where the generated function is defined.

# Runtime path: read the tuple length off the first element, then dispatch on `Val(N)`.
function _unzip_dynamic(xs::AbstractArray)
    if isempty(xs)
        msg = string(
            "unzip cannot determine the tuple length of an empty array with eltype ",
            eltype(xs),
        )
        throw(ArgumentError(msg))
    end
    x1 = first(xs)
    x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples"))
    return unzip(xs, Val(length(x1)))  # like Zygote's unzip
end

# Build the expression returned by the generated method for element type `Ts`.
# It is an ordinary function so that the generator itself stays trivial.
function _unzip_exprs(@nospecialize(Ts))
    # Tuple types without a fixed length (`Tuple`, `Tuple{Vararg{Int}}`, unions,
    # `Tuple{T,T} where T`) carry no usable parameter list: decide at run time.
    if !(Ts isa DataType) || Base.isvatuple(Ts)
        return :(_unzip_dynamic(xs))
    end
    params = collect(Ts.parameters)
    dense = filter(!Base.issingletontype, params)
    each = if length(dense) < 2 && all(isbitstype, dense)
        # Good case, no copy of data: at most one field occupies memory, so `xs` can be
        # reinterpreted as that field, and singleton fields become trivial arrays.
        # `reinterpret` only accepts bits types, hence the `isbitstype` check.
        [
            Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for
            T in params
        ]
    else
        [:(map($(Get(i)), xs)) for i in 1:length(params)]
    end
    return Expr(:tuple, each...)
end

"""
    unzip(xs::AbstractArray) -> Tuple

Convert an array of `N`-tuples into a tuple of `N` arrays, one per tuple position.

When the element type is a tuple type with at most one non-singleton field, and that
field is a bits type, the data is not copied: the result wraps `xs` via `reinterpret`.
Otherwise each component is produced by `map`.

Throws `ArgumentError` if the elements are not tuples, or if the number of components
cannot be determined because `xs` is empty and its element type has no fixed length.
"""
function unzip(xs::AbstractArray)
    return _unzip_dynamic(xs)
end

# Fallback with the tuple length supplied by the caller as `Val(N)`.
@generated function unzip(xs, ::Val{N}) where {N}
    each = [:(map($(Get(i)), xs)) for i in 1:N]
    return Expr(:tuple, each...)
end

# 1-tuples: best case is no copy at all, but `reinterpret` needs a bits type, so every
# other element type is unwrapped with `map`. The branch is resolved per `T`.
function unzip(xs::AbstractArray{Tuple{T}}) where {T}
    return isbitstype(T) ? (reinterpret(T, xs),) : (map(Get(1), xs),)
end

@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple}
    return _unzip_exprs(Ts)
end
148+
149+
# Reverse rule for `unzip`: the pullback zips the per-component cotangents back into
# an array of tuples, the inverse of what `unzip` did to `xs`.
function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray)
    function rezip(dy)
        dys = unthunk(dy)
        # Splatting a zero tangent would broadcast over the zero itself instead of over
        # component arrays, so an all-zero cotangent is passed straight through.
        dys isa ChainRulesCore.AbstractZero && return (NoTangent(), dys)
        return (NoTangent(), tuple.(dys...))
    end
    return unzip(xs), rezip
end
153+
91154
# For certain cheap operations we can easily allow fused broadcast:
92155

93156
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = split_bc_plus(args...)

0 commit comments

Comments
 (0)