Skip to content

Commit ec41724

Browse files
committed
fixup tuplecast
1 parent 8498212 commit ec41724

File tree

2 files changed

+114
-40
lines changed

2 files changed

+114
-40
lines changed

src/tuplecast.jl

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,36 @@ function tuplecast(f::F, args...) where {F}
2626
T <: Tuple || throw(ArgumentError("""tuplecast(f, args) only works on functions returning a tuple,
2727
but f = $(sprint(show, f)) returns type T = $T"""))
2828
end
29+
# TODO allow GPU arrays, possibly just as a fallback unzip, but see also:
30+
# https://github.com/JuliaArrays/StructArrays.jl/issues/150
2931
# if any(a -> a isa CuArray, args)
3032
# return unzip(broadcast(f, args...))
3133
# end
3234
bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
33-
StructArrays.components(StructArray(bc))
35+
if Broadcast.BroadcastStyle(typeof(bc)) isa Broadcast.AbstractArrayStyle
36+
return StructArrays.components(StructArray(bc))
37+
else
38+
return unzip(broadcast(f, args...)) # e.g. tuples
39+
end
3440
end
3541

3642
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplecast), f::F, args...) where {F}
37-
y, back = rrule_via_ad(cfg, broadcasted, f, args...)
43+
y, back = rrule_via_ad(cfg, broadcast, f, args...)
3844
z = unzip(y)
3945
function untuplecast(dz)
40-
dy = StructArray(map(unthunk, dz))
46+
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
47+
dy = broadcast(tuple, map(unthunk, dz)...)
4148
db, df, dargs... = back(dy)
42-
(db, sum(df), map(unbroadcast, args, dargs)...)
49+
return (db, sum(df), map(unbroadcast, args, dargs)...)
4350
end
51+
untuplecast(dz::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(dz), args))
4452
return z, untuplecast
4553
end
4654

47-
# function rrule(cfg::RCR, ::typeof(collect∘tuplecast), f, args...)
48-
# y, back = rrule(cfg, tuplecast, f, args...)
49-
# return collect(y), back
50-
# end
55+
function rrule(cfg::RCR, ::typeof(collect∘tuplecast), f, args...) # for testing, but doesn't work?
56+
y, back = rrule(cfg, tuplecast, f, args...)
57+
return collect(y), back
58+
end
5159

5260
"""
5361
tuplemap(f, args...)
@@ -64,18 +72,19 @@ function tuplemap(f::F, args...) where {F}
6472
# if any(a -> a isa CuArray, args)
6573
# return unzip(map(f, args...))
6674
# end
67-
StructArrays.components(StructArray(Iterators.map(f, args...)))
75+
return StructArrays.components(StructArray(Iterators.map(f, args...)))
6876
end
6977

70-
# function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplemap), f::F, args...) where {F}
71-
# y, back = rrule(cfg, map, f, xs...) # won't work, but also, you want the lazier fwd
72-
# z = unzip(y)
73-
# function untuplemap(dz)
74-
# dy = StructArray(map(unthunk, dz))
75-
# back(dy)
76-
# end
77-
# return unzip(xs), untuplemap
78-
# end
78+
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplemap), f::F, xs...) where {F}
79+
y, back = rrule_via_ad(cfg, map, f, xs...)
80+
z = unzip(y)
81+
function untuplemap(dz)
82+
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
83+
dy = broadcast(tuple, map(unthunk, dz)...)
84+
return back(dy)
85+
end
86+
return z, untuplemap
87+
end
7988

8089
"""
8190
unzip(A)
@@ -84,8 +93,8 @@ Converts an array of tuples into a tuple of arrays.
8493
Eager. Will work by `reinterpret` when possible.
8594
8695
```jldoctest
87-
julia> ChainRules.unzip([(1,2), (3,4), (5,6)]) # makes two new Arrays:
88-
([1, 3, 5], [2, 4, 6])
96+
julia> ChainRules.unzip([(1,2), (30,40), (500,600)]) # makes two new Arrays:
97+
([1, 30, 500], [2, 40, 600])
8998
9099
julia> typeof(ans)
91100
Tuple{Vector{Int64}, Vector{Int64}}
@@ -102,7 +111,7 @@ function unzip(xs::AbstractArray)
102111
x1 = first(xs)
103112
x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples"))
104113
N = length(x1)
105-
unzip(xs, Val(N)) # like Zygote's unzip, here this is the fallback case.
114+
return unzip(xs, Val(N)) # like Zygote's unzip, here this is the fallback case.
106115
end
107116

108117
@generated function unzip(xs, ::Val{N}) where {N}
@@ -122,16 +131,44 @@ unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best cas
122131
Expr(:tuple, each...)
123132
end
124133

134+
"""
135+
unzip(t)
136+
137+
Also works on a tuple of tuples:
138+
139+
```jldoctest
140+
julia> unzip(((1,2), (30,40), (500,600)))
141+
((1, 30, 500), (2, 40, 600))
142+
```
143+
"""
144+
function unzip(xs::Tuple)
145+
x1 = first(xs)
146+
x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays or tuples of tuples"))
147+
return ntuple(i -> map(Get(i), xs),length(x1))
148+
end
149+
125150
struct Get{i} end
126151
Get(i) = Get{Int(i)}()
127152
(::Get{i})(x) where {i} = x[i]
128153

129154
function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray{T}) where {T <: Tuple}
130155
function rezip(dy)
131-
dxs = map(unthunk.(dy)...) do ys...
132-
Tangent{T}(ys...)
156+
dxs = broadcast(xs, unthunk.(dy)...) do x, ys...
157+
ProjectTo(x)(Tangent{T}(ys...))
133158
end
134-
(NoTangent(), dxs)
159+
return (NoTangent(), dxs)
135160
end
161+
rezip(dz::AbstractZero) = (NoTangent(), dz)
136162
return unzip(xs), rezip
137163
end
164+
165+
function ChainRulesCore.rrule(::typeof(unzip), xs::Tuple)
166+
function rezip_2(dy)
167+
dxs = broadcast(xs, unthunk.(dy)...) do x, ys...
168+
Tangent{typeof(x)}(ys...)
169+
end
170+
return (NoTangent(), ProjectTo(xs)(dxs))
171+
end
172+
rezip_2(dz::AbstractZero) = (NoTangent(), dz)
173+
return unzip(xs), rezip_2
174+
end

test/tuplecast.jl

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11

2-
using ChainRules: tuplecast, unzip # tuplemap,
2+
using ChainRules: tuplecast, unzip, tuplemap
33

44
@testset "tuplecast.jl" begin
5-
@testset "basics: $(sprint(show, fun))" for fun in [tuplecast, unzip∘broadcast] # [tuplemap, tuplecast, unzip∘map, unzip∘broadcast]
5+
@testset "basics: $(sprint(show, fun))" for fun in [tuplemap, tuplecast, unzip∘map, unzip∘broadcast]
66
@test_throws Exception fun(sqrt, 1:3)
77

88
@test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6])
@@ -16,32 +16,69 @@ using ChainRules: tuplecast, unzip # tuplemap,
1616
else
1717
@test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
1818
end
19+
20+
if fun == tuplemap
21+
@test_broken fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
22+
elseif fun == unzip∘map
23+
@test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
24+
else
25+
@test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6))
26+
@test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7))
27+
@test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8))
28+
end
29+
@test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
1930
end
31+
32+
@testset "rrules" begin
33+
# These exist to allow for second derivatives
2034

21-
# tuplemap(tuple, (1,2,3), (4,5,6)) == ([1, 2, 3], [4, 5, 6])
35+
# test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], check_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any}
36+
37+
y1, bk1 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4,5,6.0])
38+
@test y1 == ([1, 2, 3], [4, 5, 6])
39+
@test bk1(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100]
40+
41+
# bk1(([1,10,100.0], NoTangent())) # DimensionMismatch in FiniteDifferences
42+
43+
y2, bk2 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4 5.0], 6.0)
44+
@test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6])
45+
@test bk2(y2)[5] ≈ 36
2246

47+
y4, bk4 = rrule(CFG, tuplemap, tuple, [1,2,3.0], [4,5,6.0])
48+
@test y4 == ([1, 2, 3], [4, 5, 6])
49+
@test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100]
50+
end
51+
2352
@testset "unzip" begin
2453
@test unzip([(1,2), (3,4), (5,6)]) == ([1, 3, 5], [2, 4, 6])
54+
@test unzip(Any[(1,2), (3,4), (5,6)]) == ([1, 3, 5], [2, 4, 6])
55+
2556
@test unzip([(nothing,2), (3,4), (5,6)]) == ([nothing, 3, 5], [2, 4, 6])
2657
@test unzip([(missing,2), (missing,4), (missing,6)])[2] isa Base.ReinterpretArray
2758

59+
@test unzip([(1,), (3,), (5,)]) == ([1, 3, 5],)
60+
@test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray
61+
62+
@test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6))
63+
64+
# test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2
65+
2866
y, bk = rrule(unzip, [(1,2), (3,4), (5,6)])
2967
@test y == ([1, 3, 5], [2, 4, 6])
3068
@test bk(Tangent{Tuple}([1,1,1], [10,100,1000]))[2] isa Vector{<:Tangent{<:Tuple}}
31-
end
32-
33-
@testset "rrules" begin
34-
# These exist to allow for second derivatives
3569

36-
# test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], check_inferred=false)
37-
y1, bk1 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4,5,6.0])
38-
@test y1 == ([1, 2, 3], [4, 5, 6])
39-
@test bk1(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100]
70+
y3, bk3 = rrule(unzip, [(1,ZeroTangent()), (3,ZeroTangent()), (5,ZeroTangent())])
71+
@test y3 == ([1, 3, 5], [ZeroTangent(), ZeroTangent(), ZeroTangent()])
72+
dx3 = bk3(Tangent{Tuple}([1,1,1], [10,100,1000]))[2]
73+
@test dx3 isa Vector{<:Tangent{<:Tuple}}
74+
@test Tuple(dx3[1]) == (1.0, NoTangent())
4075

41-
y2, bk2 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4 5.0], 6.0)
42-
@test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6])
43-
@test bk2(y2)[5] ≈ 36
44-
45-
test_rrule(unzip, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)], check_inferred=false)
76+
y5, bk5 = rrule(unzip, ((1,2), (3,4), (5,6)))
77+
@test y5 == ((1, 3, 5), (2, 4, 6))
78+
@test bk5(y5)[2] isa Tangent{<:Tuple}
79+
@test Tuple(bk5(y5)[2][2]) == (3, 4)
80+
dx5 = bk5(((1,10,100), ZeroTangent()))
81+
@test dx5[2] isa Tangent{<:Tuple}
82+
@test Tuple(dx5[2][2]) == (10, ZeroTangent())
4683
end
4784
end

0 commit comments

Comments
 (0)