diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 8aee8b9d..64a8988c 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -172,6 +172,7 @@ function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) end + # find* # simple array type that returns the index used to access an element, while @@ -249,3 +250,52 @@ end Base.findmax(a::AnyGPUArray; dims=:) = findminmax(Base.isless, a; init=typemin(eltype(a)), dims) Base.findmin(a::AnyGPUArray; dims=:) = findminmax(Base.isgreater, a; init=typemax(eltype(a)), dims) + +function Base.findall(bools::AbstractGPUArray{Bool}) + I = keytype(bools) + indices = cumsum(reshape(bools, prod(size(bools)))) + + n = @allowscalar indices[end] + ys = similar(bools, I, n) + + if n > 0 + @kernel function findall_kernel(ys, bools, indices) + i = @index(Global, Linear) + + @inbounds if i <= length(bools) && bools[i] + i′ = CartesianIndices(bools)[i] + b = indices[i] # new position + ys[b] = i′ + end + end + + kernel = findall_kernel(get_backend(ys)) + kernel(ys, bools, indices; ndrange=length(indices)) + end + + unsafe_free!(indices.data) + + return ys +end + +function Base.findall(f::Function, A::AbstractGPUArray) + bools = map(f, A) + ys = findall(bools) + unsafe_free!(bools) + return ys +end + + +# logical indexing + +# we cannot use Base.LogicalIndex, which does not support indexing but requires iteration. +# TODO: it should still be possible to use the same technique; +# Base.LogicalIndex basically contains the same as our `findall` here does. +Base.to_index(::AbstractGPUArray, I::AbstractArray{Bool}) = findall(I) +if VERSION >= v"1.11.0-DEV.1157" + Base.to_indices(A::AbstractGPUArray, I::Tuple{AbstractArray{Bool}}) = (Base.to_index(A, I[1]),) +else + Base.to_indices(A::AbstractGPUArray, inds, + I::Tuple{Union{Array{Bool,N}, BitArray{N}}}) where {N} = + (Base.to_index(A, I[1]),) +end diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index bc19cd7e..50d67b84 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -243,4 +243,14 @@ end @test compare(argmin, AT, rand(Int, 10)) @test compare(argmin, AT, -rand(Int, 10)) end + + @testset "findall" begin + # 1D + @test compare(findall, AT, rand(Bool, 100)) + @test compare(x->findall(>(0.5f0), x), AT, rand(Float32, 100)) + + # ND + @test compare(findall, AT, rand(Bool, 10, 10)) + @test compare(x->findall(>(0.5f0), x), AT, rand(Float32, 10, 10)) + end end