Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 3259594

Browse files
authored
Merge pull request #557 from JuliaGPU/tb/resize
Implement array resizing.
2 parents 398e563 + f900010 commit 3259594

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

src/array.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,36 @@ function Base.reverse(input::CuVector{T}, start=1, stop=length(input)) where {T}
462462

463463
return output
464464
end
465+
466+
467+
## resizing
468+
469+
"""
470+
resize!(a::CuVector, n::Int)
471+
472+
Resize `a` to contain `n` elements. If `n` is smaller than the current collection length,
473+
the first `n` elements will be retained. If `n` is larger, the new elements are not
474+
guaranteed to be initialized.
475+
476+
Several restrictions apply to which types of `CuArray`s can be resized:
477+
478+
- the array should be backed by the memory pool, and not have been constructed with `unsafe_wrap`
479+
- the array cannot be derived (view, reshape) from another array
480+
- the array cannot have any derived arrays itself
481+
482+
"""
483+
function Base.resize!(A::CuVector{T}, n::Int) where T
484+
A.parent === nothing || error("cannot resize derived CuArray")
485+
A.refcount == 1 || error("cannot resize shared CuArray")
486+
A.pooled || error("cannot resize wrapped CuArray")
487+
488+
ptr = convert(CuPtr{T}, alloc(n * sizeof(T)))
489+
m = min(length(A), n)
490+
unsafe_copyto!(ptr, pointer(A), m)
491+
492+
free(convert(CuPtr{Nothing}, pointer(A)))
493+
A.dims = (n,)
494+
A.ptr = ptr
495+
496+
A
497+
end

test/base.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,26 @@ end
449449
y = exp.(x)
450450
@test y isa CuArray{Complex{Float32}}
451451
end
452+
453+
@testset "resizing" begin
454+
a = CuArray([1,2,3])
455+
456+
resize!(a, 3)
457+
@test length(a) == 3
458+
@test Array(a) == [1,2,3]
459+
460+
resize!(a, 5)
461+
@test length(a) == 5
462+
@test Array(a)[1:3] == [1,2,3]
463+
464+
resize!(a, 2)
465+
@test length(a) == 2
466+
@test Array(a)[1:2] == [1,2]
467+
468+
b = view(a, 1:2)
469+
@test_throws ErrorException resize!(a, 2)
470+
@test_throws ErrorException resize!(b, 2)
471+
472+
c = unsafe_wrap(CuArray{Int}, pointer(b), 2)
473+
@test_throws ErrorException resize!(c, 2)
474+
end

0 commit comments

Comments
 (0)