Skip to content

Commit 4c13a99

Browse files
authored
Rework host indexing. (#499)
1 parent 5f40711 commit 4c13a99

File tree

3 files changed

+105
-61
lines changed

3 files changed

+105
-61
lines changed

src/host/base.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -308,21 +308,10 @@ function Adapt.adapt_storage(to::ToGPU, xs::Array)
308308
arr
309309
end
310310

311-
# we don't really want an array, so don't call `adapt(Array, ...)`,
312-
# but just want GPUArray indices to get downloaded back to the CPU.
313-
# this makes sure we preserve array-like containers, like Base.Slice.
314-
struct BackToCPU end
315-
Adapt.adapt_storage(::BackToCPU, xs::AbstractGPUArray) = convert(Array, xs)
316-
317311
@inline function Base.view(A::AbstractGPUArray, I::Vararg{Any,N}) where {N}
318312
J = to_indices(A, I)
319-
@boundscheck begin
320-
# Base's boundscheck accesses the indices, so make sure they reside on the CPU.
321-
# this is expensive, but it's a bounds check after all.
322-
J_cpu = map(j->adapt(BackToCPU(), j), J)
323-
checkbounds(A, J_cpu...)
324-
end
325313
J_gpu = map(j->adapt(ToGPU(A), j), J)
314+
@boundscheck checkbounds(A, J...)
326315
unsafe_view(A, J_gpu, GPUIndexStyle(I...))
327316
end
328317

src/host/indexing.jl

Lines changed: 98 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,77 @@
11
# host-level indexing
22

33

4-
# basic indexing with integers
4+
# indexing operators
55

66
Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear()
77

8-
function Base.getindex(xs::AbstractGPUArray{T}, I::Integer...) where T
8+
vectorized_indices(Is::Union{Integer,CartesianIndex}...) = Val{false}()
9+
vectorized_indices(Is...) = Val{true}()
10+
11+
# TODO: re-use Base functionality for the conversion of indices to a linear index,
12+
# by only implementing `getindex(A, ::Int)` etc. this is difficult due to
13+
# ambiguities with the vectorized method that can take any index type.
14+
15+
Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, Is...) =
16+
_getindex(vectorized_indices(Is...), A, to_indices(A, Is)...)
17+
Base.@propagate_inbounds _getindex(::Val{false}, A::AbstractGPUArray, Is...) =
18+
scalar_getindex(A, to_indices(A, Is)...)
19+
Base.@propagate_inbounds _getindex(::Val{true}, A::AbstractGPUArray, Is...) =
20+
vectorized_getindex(A, to_indices(A, Is)...)
21+
22+
Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, Is...) =
23+
_setindex!(vectorized_indices(Is...), A, v, to_indices(A, Is)...)
24+
Base.@propagate_inbounds _setindex!(::Val{false}, A::AbstractGPUArray, v, Is...) =
25+
scalar_setindex!(A, v, to_indices(A, Is)...)
26+
Base.@propagate_inbounds _setindex!(::Val{true}, A::AbstractGPUArray, v, Is...) =
27+
vectorized_setindex!(A, v, to_indices(A, Is)...)
28+
29+
## scalar indexing
30+
31+
function scalar_getindex(A::AbstractGPUArray{T}, Is...) where T
32+
@boundscheck checkbounds(A, Is...)
33+
I = Base._to_linear_index(A, Is...)
34+
getindex(A, I)
35+
end
36+
37+
function scalar_setindex!(A::AbstractGPUArray{T}, v, Is...) where T
38+
@boundscheck checkbounds(A, Is...)
39+
I = Base._to_linear_index(A, Is...)
40+
setindex!(A, v, I)
41+
end
42+
43+
# we still dispatch to `Base.getindex(a, ::Int)` etc so that there's a single method to
44+
# override when a back-end (e.g. with unified memory) wants to allow scalar indexing.
45+
46+
function Base.getindex(A::AbstractGPUArray{T}, I::Int) where T
47+
@boundscheck checkbounds(A, I)
948
assertscalar("getindex")
10-
i = Base._to_linear_index(xs, I...)
1149
x = Array{T}(undef, 1)
12-
copyto!(x, 1, xs, i, 1)
50+
copyto!(x, 1, A, I, 1)
1351
return x[1]
1452
end
1553

16-
function Base.setindex!(xs::AbstractGPUArray{T}, v::T, I::Integer...) where T
54+
function Base.setindex!(A::AbstractGPUArray{T}, v, I::Int) where T
55+
@boundscheck checkbounds(A, I)
1756
assertscalar("setindex!")
18-
i = Base._to_linear_index(xs, I...)
1957
x = T[v]
20-
copyto!(xs, i, x, 1, 1)
21-
return xs
58+
copyto!(A, I, x, 1, 1)
59+
return A
2260
end
2361

24-
Base.setindex!(xs::AbstractGPUArray, v, I::Integer...) =
25-
setindex!(xs, convert(eltype(xs), v), I...)
26-
62+
## vectorized indexing
2763

28-
# basic indexing with cartesian indices
29-
30-
Base.@propagate_inbounds Base.getindex(A::AbstractGPUArray, I::Union{Integer, CartesianIndex}...) =
31-
A[Base.to_indices(A, I)...]
32-
Base.@propagate_inbounds Base.setindex!(A::AbstractGPUArray, v, I::Union{Integer, CartesianIndex}...) =
33-
(A[Base.to_indices(A, I)...] = v; A)
34-
35-
36-
# generalized multidimensional indexing
37-
38-
Base.getindex(A::AbstractGPUArray, I...) = _getindex(A, to_indices(A, I)...)
39-
40-
function _getindex(src::AbstractGPUArray, Is...)
64+
function vectorized_getindex(src::AbstractGPUArray, Is...)
4165
shape = Base.index_shape(Is...)
4266
dest = similar(src, shape)
4367
any(isempty, Is) && return dest # indexing with empty array
4468
idims = map(length, Is)
4569

46-
AT = typeof(src).name.wrapper
4770
# NOTE: we are pretty liberal here supporting non-GPU indices...
48-
gpu_call(getindex_kernel, dest, src, idims, adapt(AT, Is)...)
71+
Is = map(x->adapt(ToGPU(src), x), Is)
72+
@boundscheck checkbounds(src, Is...)
73+
74+
gpu_call(getindex_kernel, dest, src, idims, Is...)
4975
return dest
5076
end
5177

@@ -61,9 +87,7 @@ end
6187
end
6288
end
6389

64-
Base.setindex!(A::AbstractGPUArray, v, I...) = _setindex!(A, v, to_indices(A, I)...)
65-
66-
function _setindex!(dest::AbstractGPUArray, src, Is...)
90+
function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
6791
isempty(Is) && return dest
6892
idims = length.(Is)
6993
len = prod(idims)
@@ -76,9 +100,11 @@ function _setindex!(dest::AbstractGPUArray, src, Is...)
76100
end
77101
end
78102

79-
AT = typeof(dest).name.wrapper
80-
# NOTE: we are pretty liberal here supporting non-GPU sources and indices...
81-
gpu_call(setindex_kernel, dest, adapt(AT, src), idims, len, adapt(AT, Is)...;
103+
# NOTE: we are pretty liberal here supporting non-GPU indices...
104+
Is = map(x->adapt(ToGPU(dest), x), Is)
105+
@boundscheck checkbounds(dest, Is...)
106+
107+
gpu_call(setindex_kernel, dest, adapt(ToGPU(dest), src), idims, len, Is...;
82108
elements=len)
83109
return dest
84110
end
@@ -96,7 +122,30 @@ end
96122
end
97123

98124

99-
## find*
125+
# bounds checking
126+
127+
# indices residing on the GPU should be bounds-checked on the GPU to avoid iteration.
128+
129+
# not all wrapped GPU arrays make sense as indices, so we use a subset of `AnyGPUArray`
130+
const IndexGPUArray{T} = Union{AbstractGPUArray{T},
131+
SubArray{T, <:Any, <:AbstractGPUArray},
132+
LinearAlgebra.Adjoint{T}}
133+
134+
@inline function Base.checkindex(::Type{Bool}, inds::AbstractUnitRange, I::IndexGPUArray)
135+
all(broadcast(I) do i
136+
Base.checkindex(Bool, inds, i)
137+
end)
138+
end
139+
140+
@inline function Base.checkindex(::Type{Bool}, inds::Tuple,
141+
I::IndexGPUArray{<:CartesianIndex})
142+
all(broadcast(I) do i
143+
Base.checkbounds_indices(Bool, inds, (i,))
144+
end)
145+
end
146+
147+
148+
# find*
100149

101150
# simple array type that returns the index used to access an element, while
102151
# retaining the dimensionality of the original array. this can be used to
@@ -107,15 +156,15 @@ struct EachIndex{T,N,IS} <: AbstractArray{T,N}
107156
dims::NTuple{N,Int}
108157
indices::IS
109158
end
110-
EachIndex(xs::AbstractArray) =
111-
EachIndex{typeof(firstindex(xs)), ndims(xs), typeof(eachindex(xs))}(
112-
size(xs), eachindex(xs))
159+
EachIndex(A::AbstractArray) =
160+
EachIndex{typeof(firstindex(A)), ndims(A), typeof(eachindex(A))}(
161+
size(A), eachindex(A))
113162
Base.size(ei::EachIndex) = ei.dims
114163
Base.getindex(ei::EachIndex, i::Int) = ei.indices[i]
115164
Base.IndexStyle(::Type{<:EachIndex}) = Base.IndexLinear()
116165

117-
function Base.findfirst(f::Function, xs::AnyGPUArray)
118-
indices = EachIndex(xs)
166+
function Base.findfirst(f::Function, A::AnyGPUArray)
167+
indices = EachIndex(A)
119168
dummy_index = first(indices)
120169

121170
# given two pairs of (istrue, index), return the one with the smallest index
@@ -130,23 +179,23 @@ function Base.findfirst(f::Function, xs::AnyGPUArray)
130179
return (false, dummy_index)
131180
end
132181

133-
res = mapreduce((x, y)->(f(x), y), reduction, xs, indices;
182+
res = mapreduce((x, y)->(f(x), y), reduction, A, indices;
134183
init = (false, dummy_index))
135184
if res[1]
136185
# out of consistency with Base.findarray, return a CartesianIndex
137186
# when the input is a multidimensional array
138-
ndims(xs) == 1 && return res[2]
139-
return CartesianIndices(xs)[res[2]]
187+
ndims(A) == 1 && return res[2]
188+
return CartesianIndices(A)[res[2]]
140189
else
141190
return nothing
142191
end
143192
end
144193

145-
Base.findfirst(xs::AnyGPUArray{Bool}) = findfirst(identity, xs)
194+
Base.findfirst(A::AnyGPUArray{Bool}) = findfirst(identity, A)
146195

147-
function findminmax(binop, xs::AnyGPUArray; init, dims)
148-
indices = EachIndex(xs)
149-
dummy_index = firstindex(xs)
196+
function findminmax(binop, A::AnyGPUArray; init, dims)
197+
indices = EachIndex(A)
198+
dummy_index = firstindex(A)
150199

151200
function reduction(t1, t2)
152201
(x, i), (y, j) = t1, t2
@@ -157,16 +206,16 @@ function findminmax(binop, xs::AnyGPUArray; init, dims)
157206
end
158207

159208
if dims == Colon()
160-
res = mapreduce(tuple, reduction, xs, indices; init = (init, dummy_index))
209+
res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index))
161210

162211
# out of consistency with Base.findarray, return a CartesianIndex
163212
# when the input is a multidimensional array
164-
return (res[1], ndims(xs) == 1 ? res[2] : CartesianIndices(xs)[res[2]])
213+
return (res[1], ndims(A) == 1 ? res[2] : CartesianIndices(A)[res[2]])
165214
else
166-
res = mapreduce(tuple, reduction, xs, indices;
215+
res = mapreduce(tuple, reduction, A, indices;
167216
init = (init, dummy_index), dims=dims)
168217
vals = map(x->x[1], res)
169-
inds = map(x->ndims(xs) == 1 ? x[2] : CartesianIndices(xs)[x[2]], res)
218+
inds = map(x->ndims(A) == 1 ? x[2] : CartesianIndices(A)[x[2]], res)
170219
return (vals, inds)
171220
end
172221
end

test/testsuite/indexing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ end
129129
@test_throws DimensionMismatch x[1:9,1:9,:,:] = y
130130
end
131131

132+
@testset "mismatching axes/indices" begin
133+
a = rand(Float32, 1,1)
134+
@test compare(a->a[1:1], AT, a)
135+
@test compare(a->a[1:1,1:1], AT, a)
136+
@test compare(a->a[1:1,1:1,1:1], AT, a)
137+
end
132138
end
133139

134140
@testsuite "indexing find" (AT, eltypes)->begin

0 commit comments

Comments
 (0)