Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 0f79e2c

Browse files
committed
Implement findall.
1 parent 51c33d7 commit 0f79e2c

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/indexing.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,51 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
4343

4444
return ys
4545
end
46+
47+
48+
## findall
49+
50+
function Base.findall(bools::CuArray{Bool})
51+
indices = cumsum(bools)
52+
53+
n = _getindex(indices, length(indices))
54+
ys = CuArray{Int}(undef, n)
55+
56+
if n > 0
57+
num_threads = min(n, 256)
58+
num_blocks = ceil(Int, length(indices) / num_threads)
59+
60+
function kernel(ys::CuDeviceArray{Int}, bools, indices)
61+
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
62+
63+
if i <= length(bools) && bools[i]
64+
b = indices[i] # new position
65+
ys[b] = i
66+
67+
end
68+
69+
return
70+
end
71+
72+
function configurator(kernel)
73+
fun = kernel.fun
74+
config = launch_configuration(fun)
75+
blocks = cld(length(indices), config.threads)
76+
77+
return (threads=config.threads, blocks=blocks)
78+
end
79+
80+
@cuda config=configurator kernel(ys, bools, indices)
81+
end
82+
83+
unsafe_free!(indices)
84+
85+
return ys
86+
end
87+
88+
function Base.findall(f::Function, A::CuArray)
89+
bools = map(f, A)
90+
ys = findall(bools)
91+
unsafe_free!(bools)
92+
return ys
93+
end

test/base.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,8 @@ end
359359
inds = rand(1:100, 150, 150)
360360
@test testf(x->permutedims(view(x, inds, :), (3, 2, 1)), rand(100, 100))
361361
end
362+
363+
@testset "findall" begin
364+
@test testf(x->findall(x), rand(Bool, 100))
365+
@test testf(x->findall(y->y>0.5, x), rand(100))
366+
end

0 commit comments

Comments
 (0)