1
1
2
2
"""
3
- tuplecast (f, args...)
3
+ unzip_broadcast (f, args...)
4
4
5
5
For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`,
6
6
but performed using `StructArrays` for efficiency. Used in the gradient of broadcasting.
7
7
8
8
# Examples
9
9
```
10
- julia> using ChainRules: tuplecast , unzip
10
+ julia> using ChainRules: unzip_broadcast , unzip
11
11
12
- julia> tuplecast (x -> (x,2x), 1:3)
12
+ julia> unzip_broadcast (x -> (x,2x), 1:3)
13
13
([1, 2, 3], [2, 4, 6])
14
14
15
- julia> mats = @btime tuplecast ((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB
15
+ julia> mats = @btime unzip_broadcast ((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB
16
16
min 1.776 ms, mean 20.421 ms (4 allocations, 15.26 MiB)
17
17
18
18
julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000))) # intermediate matrix of tuples
19
19
min 2.660 ms, mean 40.007 ms (6 allocations, 30.52 MiB)
20
20
true
21
21
```
22
22
"""
23
- function tuplecast (f:: F , args... ) where {F}
23
+ function unzip_broadcast (f:: F , args... ) where {F}
24
24
T = Broadcast. combine_eltypes (f, args)
25
25
if isconcretetype (T)
26
- T <: Tuple || throw (ArgumentError (""" tuplecast (f, args) only works on functions returning a tuple,
26
+ T <: Tuple || throw (ArgumentError (""" unzip_broadcast (f, args) only works on functions returning a tuple,
27
27
but f = $(sprint (show, f)) returns type T = $T """ ))
28
28
end
29
29
# TODO allow GPU arrays, possibly just as a fallback unzip, but see also:
@@ -39,7 +39,7 @@ function tuplecast(f::F, args...) where {F}
39
39
end
40
40
end
41
41
42
- function ChainRulesCore. rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (tuplecast ), f:: F , args... ) where {F}
42
+ function ChainRulesCore. rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (unzip_broadcast ), f:: F , args... ) where {F}
43
43
y, back = rrule_via_ad (cfg, broadcast, f, args... )
44
44
z = unzip (y)
45
45
function untuplecast (dz)
@@ -53,23 +53,25 @@ function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplec
53
53
end
54
54
55
55
# This is for testing, but the tests using it don't work.
56
- function rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (collect∘ tuplecast ), f, args... )
57
- y, back = rrule (cfg, tuplecast , f, args... )
56
+ function rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (collect∘ unzip_broadcast ), f, args... )
57
+ y, back = rrule (cfg, unzip_broadcast , f, args... )
58
58
return collect (y), back
59
59
end
60
60
61
+ #=
62
+
61
63
"""
62
- tuplemap (f, args...)
64
+ unzip_map (f, args...)
63
65
64
66
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
65
67
but performed using `StructArrays` for efficiency.
66
68
67
- Not in use at present, but see `tuplecast `.
69
+ Not in use at present, but see `unzip_broadcast `.
68
70
"""
69
- function tuplemap (f:: F , args... ) where {F}
71
+ function unzip_map (f::F, args...) where {F}
70
72
T = Broadcast.combine_eltypes(f, args)
71
73
if isconcretetype(T)
72
- T <: Tuple || throw (ArgumentError (""" tuplemap (f, args) only works on functions returning a tuple,
74
+ T <: Tuple || throw(ArgumentError("""unzip_map (f, args) only works on functions returning a tuple,
73
75
but f = $(sprint(show, f)) returns type T = $T"""))
74
76
end
75
77
# if any(a -> a isa CuArray, args)
@@ -78,17 +80,19 @@ function tuplemap(f::F, args...) where {F}
78
80
return StructArrays.components(StructArray(Iterators.map(f, args...)))
79
81
end
80
82
81
- function ChainRulesCore. rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (tuplemap ), f:: F , xs... ) where {F}
83
+ function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map ), f::F, xs...) where {F}
82
84
y, back = rrule_via_ad(cfg, map, f, xs...)
83
85
z = unzip(y)
84
- function untuplemap (dz)
86
+ function ununzip_map (dz)
85
87
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
86
88
dy = broadcast(tuple, map(unthunk, dz)...)
87
89
return back(dy)
88
90
end
89
- return z, untuplemap
91
+ return z, ununzip_map
90
92
end
91
93
94
+ =#
95
+
92
96
"""
93
97
unzip(A)
94
98
@@ -114,7 +118,7 @@ function unzip(xs::AbstractArray)
114
118
x1 = first (xs)
115
119
x1 isa Tuple || throw (ArgumentError (" unzip only accepts arrays of tuples" ))
116
120
N = length (x1)
117
- return unzip (xs, Val (N)) # like Zygote's unzip, here this is the fallback case.
121
+ return unzip (xs, Val (N)) # like Zygote's unzip. Here this is the fallback case.
118
122
end
119
123
120
124
@generated function unzip (xs, :: Val{N} ) where {N}
0 commit comments