Skip to content

Commit c50f1c8

Browse files
TheBBc42f
authored andcommitted
Enable linear indexing with multidimensional index (#644)
1 parent 9ffa695 commit c50f1c8

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

src/indexing.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ end
112112
end
113113
end
114114

115-
@propagate_inbounds function getindex(a::StaticArray, inds::StaticVector{<:Any, Int})
116-
_getindex(a, Length(inds), inds)
115+
@propagate_inbounds function getindex(a::StaticArray, inds::StaticArray{<:Tuple, Int})
116+
_getindex(a, Size(inds), inds)
117117
end
118118

119-
@generated function _getindex(a::StaticArray, ::Length{L}, inds::StaticVector{<:Any, Int}) where {L}
120-
exprs = [:(a[inds[$i]]) for i = 1:L]
119+
@generated function _getindex(a::StaticArray, s::Size{S}, inds::StaticArray{<:Tuple, Int}) where {S}
120+
exprs = [:(a[inds[$i]]) for i = 1:prod(S)]
121121
return quote
122122
@_propagate_inbounds_meta
123-
similar_type(a, Size(L))(tuple($(exprs...)))
123+
similar_type(a, s)(tuple($(exprs...)))
124124
end
125125
end
126126

@@ -159,36 +159,36 @@ end
159159
end
160160
end
161161

162-
@propagate_inbounds function setindex!(a::StaticArray, v, inds::StaticVector{<:Any, Int})
163-
_setindex!(a, v, Length(inds), inds)
162+
@propagate_inbounds function setindex!(a::StaticArray, v, inds::StaticArray{<:Tuple, Int})
163+
_setindex!(a, v, Size(inds), inds)
164164
return v
165165
end
166166

167-
@generated function _setindex!(a::StaticArray, v, ::Length{L}, inds::StaticVector{<:Any, Int}) where {L}
168-
exprs = [:(a[inds[$i]] = v) for i = 1:L]
167+
@generated function _setindex!(a::StaticArray, v, s::Size{S}, inds::StaticArray{<:Tuple, Int}) where {S}
168+
exprs = [:(a[inds[$i]] = v) for i = 1:prod(S)]
169169
return quote
170170
@_propagate_inbounds_meta
171-
similar_type(a, Size(L))(tuple($(exprs...)))
171+
similar_type(a, s)(tuple($(exprs...)))
172172
end
173173
end
174174

175-
@generated function _setindex!(a::StaticArray, v::AbstractArray, ::Length{L}, inds::StaticVector{<:Any, Int}) where {L}
176-
exprs = [:(a[inds[$i]] = v[$i]) for i = 1:L]
175+
@generated function _setindex!(a::StaticArray, v::AbstractArray, s::Size{S}, inds::StaticArray{<:Tuple, Int}) where {S}
176+
exprs = [:(a[inds[$i]] = v[$i]) for i = 1:prod(S)]
177177
return quote
178178
@_propagate_inbounds_meta
179-
if length(v) != L
180-
throw(DimensionMismatch("tried to assign $(length(v))-element array to length-$L destination"))
179+
if length(v) != $(prod(S))
180+
throw(DimensionMismatch("tried to assign $(length(v))-element array to length-$(length(inds)) destination"))
181181
end
182182
$(Expr(:block, exprs...))
183183
end
184184
end
185185

186-
@generated function _setindex!(a::StaticArray, v::StaticArray, ::Length{L}, inds::StaticVector{<:Any, Int}) where {L}
187-
exprs = [:(a[inds[$i]] = v[$i]) for i = 1:L]
186+
@generated function _setindex!(a::StaticArray, v::StaticArray, s::Size{S}, inds::StaticArray{<:Tuple, Int}) where {S}
187+
exprs = [:(a[inds[$i]] = v[$i]) for i = 1:prod(S)]
188188
return quote
189189
@_propagate_inbounds_meta
190-
if Length(typeof(v)) != L
191-
throw(DimensionMismatch("tried to assign $(length(v))-element array to length-$L destination"))
190+
if Length(typeof(v)) != Length(s)
191+
throw(DimensionMismatch("tried to assign $(length(v))-element array to length-$(length(inds)) destination"))
192192
end
193193
$(Expr(:block, exprs...))
194194
end

test/indexing.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ using StaticArrays, Test
99

1010
# Colon
1111
@test (@inferred getindex(sv,:)) === sv
12+
13+
# SArray
14+
@test (@inferred getindex(sv, SMatrix{2,2}(1,4,2,3))) === SMatrix{2,2}(4,7,5,6)
1215
end
1316

1417
@testset "Linear getindex() on SMatrix" begin
@@ -20,6 +23,9 @@ using StaticArrays, Test
2023

2124
# Colon
2225
@test (@inferred getindex(sm,:)) === sv
26+
27+
# SArray
28+
@test (@inferred getindex(sm, SMatrix{2,2}(1,4,2,3))) === SMatrix{2,2}(4,7,5,6)
2329
end
2430

2531
@testset "Linear getindex()/setindex!() on MVector" begin
@@ -39,6 +45,9 @@ using StaticArrays, Test
3945
mv = MVector(0,0,0)
4046
@test (mv[SVector(1,3)] = SVector(4, 5); (@inferred mv == MVector(4,0,5)))
4147

48+
mv = MVector(0,0,0)
49+
@test (mv[SMatrix{2,1}(1,3)] = SMatrix{2,1}(4, 5); (@inferred mv == MVector(4,0,5)))
50+
4251
# Colon
4352
mv = MVector{4,Int}(undef)
4453
@test (mv[:] = vec; (@inferred getindex(mv, :))::MVector{4,Int} == MVector((4,5,6,7)))
@@ -61,6 +70,12 @@ using StaticArrays, Test
6170
# Colon
6271
mm = MMatrix{2,2,Int}(undef)
6372
@test (mm[:] = vec; (@inferred getindex(mm, :))::MVector{4,Int} == MVector((4,5,6,7)))
73+
74+
# SMatrix
75+
mm = MMatrix{2,2,Int}(undef)
76+
mi = MMatrix{2,2}(4,2,1,3)
77+
data = @SMatrix [4 5; 6 7]
78+
@test (mm[mi] = data; (@inferred getindex(mm, :))::MVector{4,Int} == MVector((5,6,7,4)))
6479
end
6580

6681
@testset "Linear getindex()/setindex!() with a SVector on an Array" begin

0 commit comments

Comments
 (0)