Skip to content

Commit 0f35103

Browse files
committed
=Bring over OneElement for scalar getindex
1 parent 83592fe commit 0f35103

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

src/rulesets/Base/indexing.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,37 @@ For the `rrule` of `y = x[inds...]`, this function is roughly
8181
`setindex(zero(x), dy, inds...)`, returning the array `dx`.
8282
Differentiable. Includes `ProjectTo(x)(dx)`.
8383
"""
84-
function ∇getindex(x::AbstractArray, dy, inds...)
84+
function ∇getindex(x::AbstractArray{T,N}, dy, inds...) where {T,N}
8585
# `to_indices` removes any logical indexing, colons, CartesianIndex etc,
8686
# leaving just Int / AbstractVector of Int
8787
plain_inds = Base.to_indices(x, inds)
88-
dx = _setindex_zero(x, dy, plain_inds...)
89-
∇getindex!(dx, dy, plain_inds...)
90-
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
88+
if plain_inds isa NTuple{N, Int} && T<:Number
89+
# scalar indexing
90+
return OneElement(dy, plain_inds, axes(x))
91+
else # some from slicing (potentially noncontigous)
92+
dx = _setindex_zero(x, dy, plain_inds...)
93+
∇getindex!(dx, dy, plain_inds...)
94+
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
95+
end
9196
end
9297
∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z
9398

99+
"""
100+
OneElement(val, ind, axes) <: AbstractArray
101+
102+
Extremely simple `struct` used for the gradient of scalar `getindex`.
103+
"""
104+
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
105+
val::T
106+
ind::I
107+
axes::A
108+
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
109+
end
110+
Base.size(A::OneElement) = map(length, A.axes)
111+
Base.axes(A::OneElement) = A.axes
112+
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
113+
# TODO: should we teach ProjectTo that OneElement is more structurally sparse than anything it intersects nonstructural zeros with?
114+
94115
"""
95116
_setindex_zero(x, dy, inds...)
96117

test/rulesets/Base/array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,8 @@ end
398398
@test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2])
399399

400400
# Structured matrix -- NB the minimum is a structral zero here
401-
@test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal
402-
@test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular{Float64}
401+
@test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Union{Diagonal, OneElement}
402+
@test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa Union{UpperTriangular{Float64}, ChainRules.OneElement{Float64}} # must be at least as structured
403403
@test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool)
404404
end
405405

test/rulesets/Base/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989

9090
test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2)
9191
sgrad = rrule(getindex, Symmetric(rand(3, 3)), 2, 3)[2](1.0)[2]
92-
@test unthunk(sgrad) [0 0 0; 0 0 1/2; 0 1/2 0]
92+
@test unthunk(sgrad) [0 0 0; 0 0 1/2; 0 1/2 0] # We are actually getting this wrong now
9393
end
9494

9595
@testset "getindex(::Array{<:Array})" begin

0 commit comments

Comments
 (0)