Skip to content

Commit c0f0df3

Browse files
authored
Make sortperm! resilient to type mismatches. (#2051)
1 parent 5af1a7a commit c0f0df3

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/sorting.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -977,15 +977,14 @@ function Base.partialsort(c::AnyCuArray, k::Union{Integer, OrdinalRange}; kwargs
977977
return partialsort!(copy(c), k; kwargs...)
978978
end
979979

980-
function Base.sortperm!(I::AnyCuArray{T}, c::AnyCuArray; initialized=false, kwargs...) where T
981-
if length(I) != length(c)
982-
throw(ArgumentError("index vector must have the same length/indices as the source vector"))
983-
end
980+
function Base.sortperm!(ix::AnyCuArray{T}, A::AnyCuArray; initialized=false, kwargs...) where T
981+
axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
982+
984983
if !initialized
985-
I .= one(T):T(length(I))
984+
ix .= LinearIndices(A)
986985
end
987-
bitonic_sort!((c, I); kwargs...)
988-
return I
986+
bitonic_sort!((A, ix); kwargs...)
987+
return ix
989988
end
990989

991990
function Base.sortperm(c::AnyCuArray; kwargs...)

test/base/sorting.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,15 @@ function check_sortperm!(i, T, N; kwargs...)
243243
I = CuArray(i)
244244
a = rand(T, N)
245245
c = CuArray(a)
246-
CUDA.@sync sortperm!(I, c; kwargs...)
246+
sortperm!(I, c; kwargs...)
247247
return Array(I) == sortperm!(i, a; kwargs...)
248248
end
249249

250250

251251
function check_sortperm(T, N; kwargs...)
252252
a = rand(T, N)
253253
c = CuArray(a)
254-
I = CUDA.@sync sortperm(c; kwargs...)
254+
I = sortperm(c; kwargs...)
255255
return Array(I) == sortperm(a; kwargs...)
256256
end
257257

@@ -397,6 +397,8 @@ end
397397
@test check_sortperm!(collect(Int32(1):Int32(1000000)), Float32, 1000000; initialized=false)
398398
# expected error case
399399
@test_throws ArgumentError sortperm!(CuArray(1:3), CuArray(1:4))
400+
# mismatched types (JuliaGPU/CUDA.jl#2046)
401+
@test check_sortperm!(collect(UInt64(1):UInt64(1000000)), Int64, 1000000)
400402
end
401403

402404
end

0 commit comments

Comments
 (0)