Skip to content

Commit 05ebb38

Browse files
Keno and oscardssmith
authored
Remove frule for getindex(::Tuple, i) (#680)
* Remove frule for getindex(::Tuple, i) Having this chain rule is sub-optimal, because it prevents early-SROA in Diffractor-like systems that would like to perform some optimizations before applying AD (but can't do any optimization on functions that have custom rules). By letting it go down to the `getfield`, regular SROA can apply. Any AD system should handle `getfield` anyway, so I don't think there's a strong reason to have this. Similar reasoning applies to the reverse rules also, but they aren't currently actively causing me problems, so this PR only removes the frule, since I don't think many other packages are using them. We can revisit the rrules later. * Also remove the rules for first/tail For similar reasons as getindex, having a rule for first/tail is suboptimal because it suppresses early SROA. Tail is particularly problematic, because it is used in the implementation of the ``` x, y... = abc ``` syntax, of which users expect early elimination. * add getfield rule and remove tests for deleted rules --------- Co-authored-by: Oscar Smith <oscardssmith@gmail.com>
1 parent 859f6ab commit 05ebb38

File tree

2 files changed

+6
-57
lines changed

2 files changed

+6
-57
lines changed

src/rulesets/Base/indexing.jl

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
1-
#####
2-
##### getindex(::Tuple)
3-
#####
4-
5-
function frule((_, ẋ), ::typeof(getindex), x::Tuple, i::Integer)
6-
return x[i], ẋ[i]
7-
end
8-
9-
function frule((_, ẋ), ::typeof(getindex), x::Tuple, i)
10-
y = x[i]
11-
return y, Tangent{typeof(y)}(ẋ[i]...)
1+
# Int rather than Int64/Integer is intentional
2+
"""
    frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)

Forward-mode rule for positional field access on a tuple.

Return the `i`-th element of `x` together with the `i`-th element of its
tangent `ẋ`. The tangent of `getfield` itself (first slot of the tangent
tuple) is ignored.
"""
function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)
    # `x.i` is property access with the literal symbol `:i`, not the value of the
    # argument `i`; tuples have no field named `i`, so it throws for every input.
    # Look the element up positionally instead.
    return getfield(x, i), ẋ[i]
end
135

146
"for a given tuple type, returns a Val{N} where N is the length of the tuple"
@@ -77,7 +69,7 @@ end
7769
"""
7870
∇getindex(x, dy, inds...)
7971
80-
For the `rrule` of `y = x[inds...]`, this function is roughly
72+
For the `rrule` of `y = x[inds...]`, this function is roughly
8173
`setindex(zero(x), dy, inds...)`, returning the array `dx`.
8274
Differentiable. Includes `ProjectTo(x)(dx)`.
8375
"""
@@ -191,29 +183,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
191183
return dx
192184
end
193185

194-
#####
195-
##### first, tail
196-
#####
197-
198-
function frule((_, ẋ), ::typeof(first), x::Tuple)
199-
return first(x), first(ẋ)
200-
end
201-
202-
function rrule(::typeof(first), x::T) where {T<:Tuple}
203-
first_back(dy) = (NoTangent(), Tangent{T}(ntuple(j -> j == 1 ? dy : NoTangent(), _tuple_N(T))...))
204-
return first(x), first_back
205-
end
206-
207-
function frule((_, ẋ), ::typeof(Base.tail), x::Tuple)
208-
y = Base.tail(x)
209-
return y, Tangent{typeof(y)}(Base.tail(ẋ)...)
210-
end
211-
212-
function rrule(::typeof(Base.tail), x::T) where {T<:Tuple}
213-
tail_pullback(dy) = (NoTangent(), Tangent{T}(NoTangent(), dy...))
214-
return Base.tail(x), tail_pullback
215-
end
216-
217186
#####
218187
##### view
219188
#####

test/rulesets/Base/indexing.jl

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,7 @@
33
x = (1.2, 3.4, 5.6)
44
x2 = (rand(2), (a=1.0, b=x))
55

6-
# Forward
7-
test_frule(getindex, x, 2)
8-
test_frule(getindex, x2, 1)
9-
test_frule(getindex, x, 1:2)
10-
test_frule(getindex, x2, :)
11-
6+
# don't test Forward because this will be handled by lowering to getfield
127
# Reverse
138
test_rrule(getindex, x, 2)
149
@test_skip test_rrule(getindex, x2, 1, check_inferred=false) # method ambiguity, maybe fixed by https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/253
@@ -168,22 +163,7 @@
168163
end
169164
end
170165

171-
@testset "first & tail" begin
172-
x = (1.2, 3.4, 5.6)
173-
x2 = (rand(2), (a=1.0, b=x))
174-
175-
test_frule(first, x)
176-
test_frule(first, x2)
177-
178-
test_rrule(first, x)
179-
# test_rrule(first, x2) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::NoTangent, ::Tangent{NamedTuple{(:a, :b), Tuple{Float64, Tuple{Float64, Float64, Float64}}}, NamedTuple{(:a, :b), Tuple{Float64, Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}}, ::String) is ambiguous
180-
181-
test_frule(Base.tail, x, check_inferred=false) # return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}} does not match inferred return type Tuple{Tuple{Float64, Float64}, Tangent{Tuple{Float64, Float64}}}
182-
test_frule(Base.tail, x2, check_inferred=false)
183-
184-
test_rrule(Base.tail, x)
185-
test_rrule(Base.tail, x2)
186-
end
166+
# first & tail handled by getfield rules
187167

188168
@testset "view" begin
189169
test_frule(view, rand(3, 4), :, 1)

0 commit comments

Comments
 (0)