Skip to content

Commit a365fe3

Browse files
committed
rename with unzip
1 parent b6d279b commit a365fe3

File tree

5 files changed

+36
-32
lines changed

5 files changed

+36
-32
lines changed

src/ChainRules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using ChainRulesCore: derivatives_given_output
2424
const CommutativeMulNumber = Union{Real,Complex}
2525

2626
# StructArrays
27-
include("tuplecast.jl")
27+
include("unzipped.jl")
2828

2929
include("rulesets/Core/core.jl")
3030

src/rulesets/Base/broadcast.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N}
7272
@debug("split broadcasting derivatives", f, N)
7373
ys = f.(args...)
7474
function bc_many_back(dys)
75-
deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as...
75+
deltas = unzip_broadcast(unthunk(dys), ys, args...) do dy, y, as...
7676
das = only(derivatives_given_output(y, f, as...))
7777
map(da -> dy * conj(da), das) # possibly this * should be made nan-safe.
7878
end
79-
dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast?
79+
dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of unzip_broadcast?
8080
return (TRI_NO..., dargs...)
8181
end
8282
bc_many_back(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...)
@@ -105,7 +105,7 @@ split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = spl
105105
split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg)
106106
function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
107107
@debug("split broadcasting forwards", frule_fun, f)
108-
ys, ydots = tuplecast(arg) do a
108+
ys, ydots = unzip_broadcast(arg) do a
109109
frule_fun(cfg, (NoTangent(), one(a)), f, a)
110110
end
111111
function back_forwards(dys)
@@ -124,11 +124,11 @@ end
124124

125125
function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
126126
@debug("split broadcasting generic", f, N)
127-
ys3, backs = tuplecast(args...) do a...
127+
ys3, backs = unzip_broadcast(args...) do a...
128128
rrule_via_ad(cfg, f, a...)
129129
end
130130
function back_generic(dys)
131-
deltas = tuplecast(backs, unthunk(dys)) do back, dy # (could be map, sizes match)
131+
deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match)
132132
map(unthunk, back(dy))
133133
end
134134
dargs = map(unbroadcast, args, Base.tail(deltas))

src/tuplecast.jl renamed to src/unzipped.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11

22
"""
3-
tuplecast(f, args...)
3+
unzip_broadcast(f, args...)
44
55
For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`,
66
but performed using `StructArrays` for efficiency. Used in the gradient of broadcasting.
77
88
# Examples
99
```
10-
julia> using ChainRules: tuplecast, unzip
10+
julia> using ChainRules: unzip_broadcast, unzip
1111
12-
julia> tuplecast(x -> (x,2x), 1:3)
12+
julia> unzip_broadcast(x -> (x,2x), 1:3)
1313
([1, 2, 3], [2, 4, 6])
1414
15-
julia> mats = @btime tuplecast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB
15+
julia> mats = @btime unzip_broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB
1616
min 1.776 ms, mean 20.421 ms (4 allocations, 15.26 MiB)
1717
1818
julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000))) # intermediate matrix of tuples
1919
min 2.660 ms, mean 40.007 ms (6 allocations, 30.52 MiB)
2020
true
2121
```
2222
"""
23-
function tuplecast(f::F, args...) where {F}
23+
function unzip_broadcast(f::F, args...) where {F}
2424
T = Broadcast.combine_eltypes(f, args)
2525
if isconcretetype(T)
26-
T <: Tuple || throw(ArgumentError("""tuplecast(f, args) only works on functions returning a tuple,
26+
T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple,
2727
but f = $(sprint(show, f)) returns type T = $T"""))
2828
end
2929
# TODO allow GPU arrays, possibly just as a fallback unzip, but see also:
@@ -39,7 +39,7 @@ function tuplecast(f::F, args...) where {F}
3939
end
4040
end
4141

42-
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplecast), f::F, args...) where {F}
42+
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_broadcast), f::F, args...) where {F}
4343
y, back = rrule_via_ad(cfg, broadcast, f, args...)
4444
z = unzip(y)
4545
function untuplecast(dz)
@@ -53,23 +53,25 @@ function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplec
5353
end
5454

5555
# This is for testing, but the tests using it don't work.
56-
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collecttuplecast), f, args...)
57-
y, back = rrule(cfg, tuplecast, f, args...)
56+
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collectunzip_broadcast), f, args...)
57+
y, back = rrule(cfg, unzip_broadcast, f, args...)
5858
return collect(y), back
5959
end
6060

61+
#=
62+
6163
"""
62-
tuplemap(f, args...)
64+
unzip_map(f, args...)
6365
6466
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
6567
but performed using `StructArrays` for efficiency.
6668
67-
Not in use at present, but see `tuplecast`.
69+
Not in use at present, but see `unzip_broadcast`.
6870
"""
69-
function tuplemap(f::F, args...) where {F}
71+
function unzip_map(f::F, args...) where {F}
7072
T = Broadcast.combine_eltypes(f, args)
7173
if isconcretetype(T)
72-
T <: Tuple || throw(ArgumentError("""tuplemap(f, args) only works on functions returning a tuple,
74+
T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple,
7375
but f = $(sprint(show, f)) returns type T = $T"""))
7476
end
7577
# if any(a -> a isa CuArray, args)
@@ -78,17 +80,19 @@ function tuplemap(f::F, args...) where {F}
7880
return StructArrays.components(StructArray(Iterators.map(f, args...)))
7981
end
8082
81-
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplemap), f::F, xs...) where {F}
83+
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F}
8284
y, back = rrule_via_ad(cfg, map, f, xs...)
8385
z = unzip(y)
84-
function untuplemap(dz)
86+
function ununzip_map(dz)
8587
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
8688
dy = broadcast(tuple, map(unthunk, dz)...)
8789
return back(dy)
8890
end
89-
return z, untuplemap
91+
return z, ununzip_map
9092
end
9193
94+
=#
95+
9296
"""
9397
unzip(A)
9498
@@ -114,7 +118,7 @@ function unzip(xs::AbstractArray)
114118
x1 = first(xs)
115119
x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples"))
116120
N = length(x1)
117-
return unzip(xs, Val(N)) # like Zygote's unzip, here this is the fallback case.
121+
return unzip(xs, Val(N)) # like Zygote's unzip. Here this is the fallback case.
118122
end
119123

120124
@generated function unzip(xs, ::Val{N}) where {N}

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
include_test("rulesets/Base/sort.jl")
6060
include_test("rulesets/Base/broadcast.jl")
6161

62-
include_test("tuplecast.jl") # used primarily for broadcast
62+
include_test("unzipped.jl") # used primarily for broadcast
6363

6464
println()
6565

test/tuplecast.jl renamed to test/unzipped.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11

2-
using ChainRules: tuplecast, unzip, tuplemap
2+
using ChainRules: unzip_broadcast, unzip #, unzip_map
33

4-
@testset "tuplecast.jl" begin
5-
@testset "basics: $(sprint(show, fun))" for fun in [tuplemap, tuplecast, unzipmap, unzipbroadcast]
4+
@testset "unzip_broadcast.jl" begin
5+
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast] # unzip_map,
66
@test_throws Exception fun(sqrt, 1:3)
77

88
@test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6])
@@ -17,7 +17,7 @@ using ChainRules: tuplecast, unzip, tuplemap
1717
@test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
1818
end
1919

20-
if fun == tuplemap
20+
if fun == unzip_map
2121
@test_broken fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
2222
elseif fun == unzipmap
2323
@test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
@@ -32,19 +32,19 @@ using ChainRules: tuplecast, unzip, tuplemap
3232
@testset "rrules" begin
3333
# These exist to allow for second derivatives
3434

35-
# test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], collectheck_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any}
35+
# test_rrule(collect∘unzip_broadcast, tuple, [1,2,3.], [4,5,6.], collectheck_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any}
3636

37-
y1, bk1 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4,5,6.0])
37+
y1, bk1 = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4,5,6.0])
3838
@test y1 == ([1, 2, 3], [4, 5, 6])
3939
@test bk1(([1,10,100.0], [7,8,9.0]))[3] [1,10,100]
4040

4141
# bk1(([1,10,100.0], NoTangent())) # DimensionMismatch in FiniteDifferences
4242

43-
y2, bk2 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4 5.0], 6.0)
43+
y2, bk2 = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4 5.0], 6.0)
4444
@test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6])
4545
@test bk2(y2)[5] 36
4646

47-
y4, bk4 = rrule(CFG, tuplemap, tuple, [1,2,3.0], [4,5,6.0])
47+
y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0])
4848
@test y4 == ([1, 2, 3], [4, 5, 6])
4949
@test bk4(([1,10,100.0], [7,8,9.0]))[3] [1,10,100]
5050
end

0 commit comments

Comments
 (0)