Skip to content

Commit ccbe561

Browse files
committed
tidy unzipped
1 parent 45102c5 commit ccbe561

File tree

3 files changed

+21
-39
lines changed

3 files changed

+21
-39
lines changed

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
44
using ChainRulesCore
55
using Compat
66
using Distributed
7+
using GPUArraysCore: AbstractGPUArrayStyle
78
using IrrationalConstants: logtwo, logten
89
using LinearAlgebra
910
using LinearAlgebra.BLAS

src/unzipped.jl

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#####
2+
##### broadcast
3+
#####
14

25
"""
36
unzip_broadcast(f, args...)
@@ -26,17 +29,18 @@ function unzip_broadcast(f::F, args...) where {F}
2629
T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
2730
but f = $(sprint(show, f)) returns type T = $T"""))
2831
end
29-
# TODO allow GPU arrays, possibly just as a fallback unzip, but see also:
30-
# https://github.com/JuliaArrays/StructArrays.jl/issues/150
31-
# if any(a -> a isa CuArray, args)
32-
# return unzip(broadcast(f, args...))
33-
# end
3432
bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
35-
if Broadcast.BroadcastStyle(typeof(bc)) isa Broadcast.AbstractArrayStyle
33+
bcs = Broadcast.BroadcastStyle(typeof(bc))
34+
if bcs isa AbstractGPUArrayStyle
35+
# This is a crude way to allow GPU arrays, not currently tested, TODO.
36+
# See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
37+
return unzip(broadcast(f, args...))
38+
elseif bcs isa Broadcast.AbstractArrayStyle
3639
return StructArrays.components(StructArray(bc))
3740
else
3841
return unzip(broadcast(f, args...)) # e.g. tuples
3942
end
43+
# TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`?
4044
end
4145

4246
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_broadcast), f::F, args...) where {F}
@@ -58,40 +62,17 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘unzip_broad
5862
return collect(y), back
5963
end
6064

61-
#=
65+
#####
66+
##### map
67+
#####
6268

63-
"""
64-
unzip_map(f, args...)
65-
66-
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
67-
but performed using `StructArrays` for efficiency.
68-
69-
Not in use at present, but see `unzip_broadcast`.
70-
"""
71-
function unzip_map(f::F, args...) where {F}
72-
T = Broadcast.combine_eltypes(f, args)
73-
if isconcretetype(T)
74-
T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple,
75-
but f = $(sprint(show, f)) returns type T = $T"""))
76-
end
77-
# if any(a -> a isa CuArray, args)
78-
# return unzip(map(f, args...))
79-
# end
80-
return StructArrays.components(StructArray(Iterators.map(f, args...)))
81-
end
69+
# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
70+
# will be useful for the gradient of `map` etc.
8271

83-
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F}
84-
y, back = rrule_via_ad(cfg, map, f, xs...)
85-
z = unzip(y)
86-
function ununzip_map(dz)
87-
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
88-
dy = broadcast(tuple, map(unthunk, dz)...)
89-
return back(dy)
90-
end
91-
return z, ununzip_map
92-
end
9372

94-
=#
73+
#####
74+
##### unzip
75+
#####
9576

9677
"""
9778
unzip(A)

test/unzipped.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
1717
@test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
1818
end
1919

20-
if fun == unzipmap
20+
if contains(string(fun), "map")
2121
@test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
2222
else
2323
@test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
@@ -26,7 +26,7 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
2626
end
2727
@test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
2828
end
29-
29+
3030
@testset "rrules" begin
3131
# These exist to allow for second derivatives
3232

0 commit comments

Comments
 (0)