1
+ # ####
2
+ # #### broadcast
3
+ # ####
1
4
2
5
"""
3
6
unzip_broadcast(f, args...)
@@ -26,17 +29,18 @@ function unzip_broadcast(f::F, args...) where {F}
26
29
T <: Tuple || throw (ArgumentError (""" unzip_broadcast(f, args) only works on functions returning a tuple,
27
30
but f = $(sprint (show, f)) returns type T = $T """ ))
28
31
end
29
- # TODO allow GPU arrays, possibly just as a fallback unzip, but see also:
30
- # https://github.com/JuliaArrays/StructArrays.jl/issues/150
31
- # if any(a -> a isa CuArray, args)
32
- # return unzip(broadcast(f, args...))
33
- # end
34
32
bc = Broadcast. instantiate (Broadcast. broadcasted (f, args... ))
35
- if Broadcast. BroadcastStyle (typeof (bc)) isa Broadcast. AbstractArrayStyle
33
+ bcs = Broadcast. BroadcastStyle (typeof (bc))
34
+ if bcs isa AbstractGPUArrayStyle
35
+ # This is a crude way to allow GPU arrays, not currently tested, TODO .
36
+ # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150
37
+ return unzip (broadcast (f, args... ))
38
+ elseif bcs isa Broadcast. AbstractArrayStyle
36
39
return StructArrays. components (StructArray (bc))
37
40
else
38
41
return unzip (broadcast (f, args... )) # e.g. tuples
39
42
end
43
+ # TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`?
40
44
end
41
45
42
46
function ChainRulesCore. rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (unzip_broadcast), f:: F , args... ) where {F}
@@ -58,40 +62,17 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘unzip_broad
58
62
return collect (y), back
59
63
end
60
64
61
- #=
65
+ # ####
66
+ # #### map
67
+ # ####
62
68
63
- """
64
- unzip_map(f, args...)
65
-
66
- For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
67
- but performed using `StructArrays` for efficiency.
68
-
69
- Not in use at present, but see `unzip_broadcast`.
70
- """
71
- function unzip_map(f::F, args...) where {F}
72
- T = Broadcast.combine_eltypes(f, args)
73
- if isconcretetype(T)
74
- T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple,
75
- but f = $(sprint(show, f)) returns type T = $T"""))
76
- end
77
- # if any(a -> a isa CuArray, args)
78
- # return unzip(map(f, args...))
79
- # end
80
- return StructArrays.components(StructArray(Iterators.map(f, args...)))
81
- end
69
+ # `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
70
+ # will be useful for the gradient of `map` etc.
82
71
83
- function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F}
84
- y, back = rrule_via_ad(cfg, map, f, xs...)
85
- z = unzip(y)
86
- function ununzip_map(dz)
87
- # dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
88
- dy = broadcast(tuple, map(unthunk, dz)...)
89
- return back(dy)
90
- end
91
- return z, ununzip_map
92
- end
93
72
94
- =#
73
+ # ####
74
+ # #### unzip
75
+ # ####
95
76
96
77
"""
97
78
unzip(A)
0 commit comments