Skip to content

Commit 58ef36a

Browse files
authored
Port find functions from CUDA (#436)
Changes: - use linear indices in the kernel, converting the output to cartesian if needed - create a helper EachIndex struct that retains dimensionality (previously, we were needlessly launching threads) - simplify NaN handling by using Base.isless and Base.isgreater - deobfuscate the implementation
1 parent 844b253 commit 58ef36a

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

src/host/indexing.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,96 @@ end
8787
return
8888
end
8989
end
90+
91+
92+
## find*
93+
94+
# simple array type that returns the index used to access an element, while
95+
# retaining the dimensionality of the original array. this can be used to
96+
# broadcast or reduce an array together with its indices, whereas normally
97+
# combining e.g. a 2x2 array with its 4-element eachindex array would result
98+
# in a 4x4 broadcast or reduction.
99+
struct EachIndex{T,N,IS} <: AbstractArray{T,N}
100+
dims::NTuple{N,Int}
101+
indices::IS
102+
end
103+
EachIndex(xs::AbstractArray) =
104+
EachIndex{typeof(firstindex(xs)), ndims(xs), typeof(eachindex(xs))}(
105+
size(xs), eachindex(xs))
106+
Base.size(ei::EachIndex) = ei.dims
107+
Base.getindex(ei::EachIndex, i::Int) = ei.indices[i]
108+
Base.IndexStyle(::Type{<:EachIndex}) = Base.IndexLinear()
109+
110+
function Base.findfirst(f::Function, xs::AnyGPUArray)
111+
indices = EachIndex(xs)
112+
dummy_index = first(indices)
113+
114+
# given two pairs of (istrue, index), return the one with the smallest index
115+
function reduction(t1, t2)
116+
(x, i), (y, j) = t1, t2
117+
if i > j
118+
t1, t2 = t2, t1
119+
(x, i), (y, j) = t1, t2
120+
end
121+
x && return t1
122+
y && return t2
123+
return (false, dummy_index)
124+
end
125+
126+
res = mapreduce((x, y)->(f(x), y), reduction, xs, indices;
127+
init = (false, dummy_index))
128+
if res[1]
129+
# out of consistency with Base.findarray, return a CartesianIndex
130+
# when the input is a multidimensional array
131+
ndims(xs) == 1 && return res[2]
132+
return CartesianIndices(xs)[res[2]]
133+
else
134+
return nothing
135+
end
136+
end
137+
138+
Base.findfirst(xs::AnyGPUArray{Bool}) = findfirst(identity, xs)
139+
140+
function findminmax(binop, xs::AnyGPUArray; init, dims)
141+
indices = EachIndex(xs)
142+
dummy_index = firstindex(xs)
143+
144+
function reduction(t1, t2)
145+
(x, i), (y, j) = t1, t2
146+
147+
binop(x, y) && return t2
148+
x == y && return (x, min(i, j))
149+
return t1
150+
end
151+
152+
@static if VERSION < v"1.7.0-DEV.119"
153+
# before JuliaLang/julia#35316, isless/isgreated did not order NaNs last
154+
function reduction(t1, t2)
155+
(x, i), (y, j) = t1, t2
156+
157+
isnan(x) && return t1
158+
isnan(y) && return t2
159+
160+
binop(x, y) && return t2
161+
x == y && return (x, min(i, j))
162+
return t1
163+
end
164+
end
165+
166+
if dims == Colon()
167+
res = mapreduce(tuple, reduction, xs, indices; init = (init, dummy_index))
168+
169+
# out of consistency with Base.findarray, return a CartesianIndex
170+
# when the input is a multidimensional array
171+
return (res[1], ndims(xs) == 1 ? res[2] : CartesianIndices(xs)[res[2]])
172+
else
173+
res = mapreduce(tuple, reduction, xs, indices;
174+
init = (init, dummy_index), dims=dims)
175+
vals = map(x->x[1], res)
176+
inds = map(x->ndims(xs) == 1 ? x[2] : CartesianIndices(xs)[x[2]], res)
177+
return (vals, inds)
178+
end
179+
end
180+
181+
Base.findmax(a::AnyGPUArray; dims=:) = findminmax(Base.isless, a; init=typemin(eltype(a)), dims)
182+
Base.findmin(a::AnyGPUArray; dims=:) = findminmax(Base.isgreater, a; init=typemax(eltype(a)), dims)

test/testsuite/indexing.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,94 @@ end
119119
@test compare((X,Y)->(X[1,:] = Y), AT, zeros(Float32, 2,2), ones(Float32, 2))
120120
end
121121
end
122+
123+
@testsuite "indexing find" (AT, eltypes)->begin
124+
@testset "findfirst" begin
125+
# 1D
126+
@test compare(findfirst, AT, rand(Bool, 100))
127+
@test compare(x->findfirst(>(0.5f0), x), AT, rand(Float32, 100))
128+
let x = fill(false, 10)
129+
@test findfirst(x) == findfirst(AT(x))
130+
end
131+
132+
# ND
133+
let x = rand(Bool, 10, 10)
134+
@test findfirst(x) == findfirst(AT(x))
135+
end
136+
let x = rand(Float32, 10, 10)
137+
@test findfirst(>(0.5f0), x) == findfirst(>(0.5f0), AT(x))
138+
end
139+
end
140+
141+
@testset "findmax & findmin" begin
142+
let x = rand(Float32, 100)
143+
@test findmax(x) == findmax(AT(x))
144+
@test findmax(x; dims=1) == Array.(findmax(AT(x); dims=1))
145+
146+
x[32] = x[33] = x[55] = x[66] = NaN32
147+
@test isequal(findmax(x), findmax(AT(x)))
148+
@test isequal(findmax(x; dims=1), Array.(findmax(AT(x); dims=1)))
149+
end
150+
let x = rand(Float32, 10, 10)
151+
@test findmax(x) == findmax(AT(x))
152+
@test findmax(x; dims=1) == Array.(findmax(AT(x); dims=1))
153+
@test findmax(x; dims=2) == Array.(findmax(AT(x); dims=2))
154+
155+
x[rand(CartesianIndices((10, 10)), 10)] .= NaN
156+
@test isequal(findmax(x), findmax(AT(x)))
157+
@test isequal(findmax(x; dims=1), Array.(findmax(AT(x); dims=1)))
158+
end
159+
let x = rand(Float32, 10, 10, 10)
160+
@test findmax(x) == findmax(AT(x))
161+
@test findmax(x; dims=1) == Array.(findmax(AT(x); dims=1))
162+
@test findmax(x; dims=2) == Array.(findmax(AT(x); dims=2))
163+
@test findmax(x; dims=3) == Array.(findmax(AT(x); dims=3))
164+
165+
x[rand(CartesianIndices((10, 10, 10)), 20)] .= NaN
166+
@test isequal(findmax(x), findmax(AT(x)))
167+
@test isequal(findmax(x; dims=1), Array.(findmax(AT(x); dims=1)))
168+
@test isequal(findmax(x; dims=2), Array.(findmax(AT(x); dims=2)))
169+
@test isequal(findmax(x; dims=3), Array.(findmax(AT(x); dims=3)))
170+
end
171+
172+
let x = rand(Float32, 100)
173+
@test findmin(x) == findmin(AT(x))
174+
@test findmin(x; dims=1) == Array.(findmin(AT(x); dims=1))
175+
176+
x[32] = x[33] = x[55] = x[66] = NaN32
177+
@test isequal(findmin(x), findmin(AT(x)))
178+
@test isequal(findmin(x; dims=1), Array.(findmin(AT(x); dims=1)))
179+
end
180+
let x = rand(Float32, 10, 10)
181+
@test findmin(x) == findmin(AT(x))
182+
@test findmin(x; dims=1) == Array.(findmin(AT(x); dims=1))
183+
@test findmin(x; dims=2) == Array.(findmin(AT(x); dims=2))
184+
185+
x[rand(CartesianIndices((10, 10)), 10)] .= NaN
186+
@test isequal(findmin(x), findmin(AT(x)))
187+
@test isequal(findmin(x; dims=1), Array.(findmin(AT(x); dims=1)))
188+
@test isequal(findmin(x; dims=2), Array.(findmin(AT(x); dims=2)))
189+
@test isequal(findmin(x; dims=3), Array.(findmin(AT(x); dims=3)))
190+
end
191+
let x = rand(Float32, 10, 10, 10)
192+
@test findmin(x) == findmin(AT(x))
193+
@test findmin(x; dims=1) == Array.(findmin(AT(x); dims=1))
194+
@test findmin(x; dims=2) == Array.(findmin(AT(x); dims=2))
195+
@test findmin(x; dims=3) == Array.(findmin(AT(x); dims=3))
196+
197+
x[rand(CartesianIndices((10, 10, 10)), 20)] .= NaN
198+
@test isequal(findmin(x), findmin(AT(x)))
199+
@test isequal(findmin(x; dims=1), Array.(findmin(AT(x); dims=1)))
200+
@test isequal(findmin(x; dims=2), Array.(findmin(AT(x); dims=2)))
201+
@test isequal(findmin(x; dims=3), Array.(findmin(AT(x); dims=3)))
202+
end
203+
end
204+
205+
@testset "argmax & argmin" begin
206+
@test compare(argmax, AT, rand(Int, 10))
207+
@test compare(argmax, AT, -rand(Int, 10))
208+
209+
@test compare(argmin, AT, rand(Int, 10))
210+
@test compare(argmin, AT, -rand(Int, 10))
211+
end
212+
end

0 commit comments

Comments
 (0)