Skip to content

Commit 8073c7c

Browse files
mcabbottoxinabox
andauthored
Rules for getindex(::Tuple) and sum(::Tuple) (#643)
* getindex for tuples * sum for tuples * repeated indices in getindex, etc * add colon case * tidy, use Tanget * first, tail * simplify * Apply 2 suggestions Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au> * skip a test until JuliaDiff/ChainRulesTestUtils.jl#253, bump version * comment on Zygote Co-authored-by: Frames Catherine White <oxinabox@ucc.asn.au>
1 parent dadb205 commit 8073c7c

File tree

5 files changed

+133
-1
lines changed

5 files changed

+133
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.37.0"
3+
version = "1.38.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/indexing.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,58 @@
1+
#####
2+
##### getindex(::Tuple)
3+
#####
4+
5+
function frule((_, ẋ), ::typeof(getindex), x::Tuple, i::Integer)
6+
return x[i], ẋ[i]
7+
end
8+
9+
function frule((_, ẋ), ::typeof(getindex), x::Tuple, i)
10+
y = x[i]
11+
return y, Tangent{typeof(y)}(ẋ[i]...)
12+
end
13+
14+
"for a given typle type, returns a Val{N} where N is the length of the tuple"
15+
_tuple_N(::Type{<:Tuple{Vararg{<:Any, N}}}) where {N} = Val(N)
16+
17+
function rrule(::typeof(getindex), x::T, i::Integer) where {T<:Tuple}
18+
function getindex_back_1(dy)
19+
dx = ntuple(j -> j == i ? dy : NoTangent(), _tuple_N(T))
20+
return (NoTangent(), Tangent{T}(dx...), NoTangent())
21+
end
22+
return x[i], getindex_back_1
23+
end
24+
25+
# Special case for tuples of only numbers
26+
function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Number}}
27+
function getindex_back_2(dy_raw)
28+
dy = unthunk(dy_raw)
29+
dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T))
30+
return (NoTangent(), Tangent{T}(dx...), NoTangent())
31+
end
32+
return x[i], getindex_back_2
33+
end
34+
35+
# Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector),
36+
# whether that's more efficient has not been investigated here.
37+
# https://github.com/FluxML/Zygote.jl/blob/master/src/lib/lib.jl#L125-L142
38+
function rrule(::typeof(getindex), x::T, inds) where {T<:Tuple} # e.g. ranges, not type-stable
39+
function getindex_back_3(dy_raw)
40+
dy = unthunk(dy_raw)
41+
dx = ntuple(Returns(NoTangent()), _tuple_N(T))
42+
for (dyi, i) in zip(dy, inds)
43+
dx = Base.setindex(dx, dyi + dx[i], i)
44+
end
45+
return (NoTangent(), Tangent{T}(dx...), NoTangent())
46+
end
47+
return x[inds], getindex_back_3
48+
end
49+
50+
function rrule(::typeof(getindex), x::Tuple, ::Colon)
51+
getindex_back_4(dy) = (NoTangent(), dy, NoTangent())
52+
return x, getindex_back_4
53+
end
54+
55+
156
#####
257
##### getindex
358
#####
@@ -31,6 +86,29 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
3186
return y, getindex_pullback
3287
end
3388

89+
#####
90+
##### first, tail
91+
#####
92+
93+
function frule((_, ẋ), ::typeof(first), x::Tuple)
94+
return first(x), first(ẋ)
95+
end
96+
97+
function rrule(::typeof(first), x::T) where {T<:Tuple}
98+
first_back(dy) = (NoTangent(), Tangent{T}(ntuple(j -> j == 1 ? dy : NoTangent(), _tuple_N(T))...))
99+
return first(x), first_back
100+
end
101+
102+
function frule((_, ẋ), ::typeof(Base.tail), x::Tuple)
103+
y = Base.tail(x)
104+
return y, Tangent{typeof(y)}(Base.tail(Tuple(ẋ))...)
105+
end
106+
107+
function rrule(::typeof(Base.tail), x::T) where {T<:Tuple}
108+
tail_pullback(dy) = (NoTangent(), Tangent{T}(NoTangent(), dy...))
109+
return Base.tail(x), tail_pullback
110+
end
111+
34112
#####
35113
##### view
36114
#####

src/rulesets/Base/mapreduce.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,17 @@ function frule((_, ẏ, ẋ), ::typeof(sum!), y::AbstractArray, x::AbstractArray
1414
return sum!(y, x), sum!(ẏ, ẋ)
1515
end
1616

17+
function rrule(::typeof(sum), x::Tuple)
18+
project = ProjectTo(x)
19+
len = Val(length(x))
20+
function sum_pullback(dy_raw)
21+
dy = unthunk(dy_raw)
22+
dx = dy isa AbstractZero ? dy : ntuple(Returns(dy), len)
23+
return (NoTangent(), project(dx))
24+
end
25+
return sum(x), sum_pullback
26+
end
27+
1728
function rrule(::typeof(sum), x::AbstractArray; dims=:)
1829
project = ProjectTo(x)
1930
y = sum(x; dims=dims)

test/rulesets/Base/indexing.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,25 @@
11
@testset "getindex" begin
2+
@testset "getindex(::Tuple, ...)" begin
3+
x = (1.2, 3.4, 5.6)
4+
x2 = (rand(2), (a=1.0, b=x))
5+
6+
# Forward
7+
test_frule(getindex, x, 2)
8+
test_frule(getindex, x2, 1)
9+
test_frule(getindex, x, 1:2)
10+
test_frule(getindex, x2, :)
11+
12+
# Reverse
13+
test_rrule(getindex, x, 2)
14+
@test_skip test_rrule(getindex, x2, 1, check_inferred=false) # method ambiguity, maybe fixed by https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/253
15+
16+
test_rrule(getindex, x, 2:3; check_inferred=false)
17+
test_rrule(getindex, x, [1, 1, 2], check_inferred=false)
18+
test_rrule(getindex, x2, 1:2, check_inferred=false)
19+
20+
test_rrule(getindex, x, :)
21+
end
22+
223
@testset "getindex(::Matrix{<:Number}, ...)" begin
324
x = [1.0 2.0 3.0; 10.0 20.0 30.0]
425

@@ -58,6 +79,23 @@
5879
end
5980
end
6081

82+
@testset "first & tail" begin
83+
x = (1.2, 3.4, 5.6)
84+
x2 = (rand(2), (a=1.0, b=x))
85+
86+
test_frule(first, x)
87+
test_frule(first, x2)
88+
89+
test_rrule(first, x)
90+
# test_rrule(first, x2) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::NoTangent, ::Tangent{NamedTuple{(:a, :b), Tuple{Float64, Tuple{Float64, Float64, Float64}}}, NamedTuple{(:a, :b), Tuple{Float64, Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}}, ::String) is ambiguous
91+
92+
test_frule(Base.tail, x, check_inferred=false) # return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}} does not match inferred return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}}}
93+
test_frule(Base.tail, x2, check_inferred=false)
94+
95+
test_rrule(Base.tail, x)
96+
test_rrule(Base.tail, x2)
97+
end
98+
6199
@testset "view" begin
62100
test_frule(view, rand(3, 4), :, 1)
63101
test_frule(view, rand(3, 4), 2, [1, 1, 2])

test/rulesets/Base/mapreduce.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
77
@testset "Reductions" begin
88
@testset "sum(::Tuple)" begin
99
test_frule(sum, Tuple(rand(5)))
10+
test_frule(sum, (rand(2), rand(2)))
11+
12+
test_rrule(sum, Tuple(rand(5)))
13+
test_rrule(sum, (1.2, 3.4 + 5im))
14+
test_rrule(sum, (rand(2)', rand(1,2)))
1015
end
1116
@testset "sum(x; dims=$dims)" for dims in (:, 2, (1,3))
1217
# Forward

0 commit comments

Comments
 (0)