Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 33a1495

Browse files
committed
Use unified memory for array allocations.
1 parent ebb8339 commit 33a1495

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

src/indexing.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,31 @@
11
import GPUArrays: allowscalar, @allowscalar
22

3-
function _getindex(xs::CuArray{T}, i::Integer) where T
4-
buf = Array{T}(undef)
5-
copyto!(buf, 1, xs, i, 1)
6-
buf[]
3+
4+
## unified memory indexing
5+
6+
# TODO: think about coherency -- this path does not synchronize, so the host access might crash
7+
# Also, this optimization would be relevant for CuArray<->Array memcpy as well.
8+
9+
# Scalar read: return the element of `xs` at linear index `i`.
# When the array's buffer is a `Mem.UnifiedBuffer`, the element is loaded directly
# through a pointer; otherwise one element is copied into a zero-dimensional `Array`.
function GPUArrays._getindex(xs::CuArray{T}, i::Integer) where T
10+
buf = buffer(xs)
11+
if isa(buf, Mem.UnifiedBuffer)
12+
# NOTE(review): nothing visible here synchronizes before the load -- confirm
# that pending GPU work on `xs` cannot race with this access.
ptr = convert(Ptr{T}, buffer(xs))
13+
# `unsafe_load` does not bounds-check `i`; presumably callers validate it -- TODO confirm.
unsafe_load(ptr, i)
14+
else
15+
# Fallback for other buffer types: copy a single element into a 0-dimensional
# array and unwrap it with `[]`.
val = Array{T}(undef)
16+
copyto!(val, 1, xs, i, 1)
17+
val[]
18+
end
719
end
820

9-
function _setindex!(xs::CuArray{T}, v::T, i::Integer) where T
10-
copyto!(xs, i, T[v], 1, 1)
21+
# Scalar write: store `v` at linear index `i` of `xs`.
# This method only applies when `v` already has the array's element type `T`.
# When the array's buffer is a `Mem.UnifiedBuffer`, the value is stored directly
# through a pointer; otherwise it is copied from a temporary one-element vector.
function GPUArrays._setindex!(xs::CuArray{T}, v::T, i::Integer) where T
22+
buf = buffer(xs)
23+
if isa(buf, Mem.UnifiedBuffer)
24+
# NOTE(review): nothing visible here synchronizes before the store -- confirm
# that pending GPU work on `xs` cannot race with this access.
ptr = convert(Ptr{T}, buffer(xs))
25+
# `unsafe_store!` does not bounds-check `i`; presumably callers validate it -- TODO confirm.
unsafe_store!(ptr, v, i)
26+
else
27+
# Fallback for other buffer types: wrap `v` in a one-element vector `T[v]` and
# copy that single element to position `i` of `xs`.
copyto!(xs, i, T[v], 1, 1)
28+
end
1129
end
1230

1331

@@ -19,7 +37,7 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
1937
bools = reshape(bools, prod(size(bools)))
2038
indices = cumsum(bools) # unique indices for elements that are true
2139

22-
n = _getindex(indices, length(indices)) # number that are true
40+
n = GPUArrays._getindex(indices, length(indices)) # number that are true
2341
ys = CuArray{T}(undef, n)
2442

2543
if n > 0

src/memory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function try_cuda_alloc(bytes)
241241
buf = nothing
242242
try
243243
stats.cuda_time += Base.@elapsed begin
244-
buf = Mem.alloc(Mem.Device, bytes)
244+
buf = Mem.alloc(Mem.Unified, bytes)
245245
end
246246
stats.actual_nalloc += 1
247247
stats.actual_alloc += bytes

src/solver/CUSOLVER.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import CUDAdrv: CUDAdrv, CuContext, CuStream_t, CuPtr, PtrOrCuPtr, CU_NULL
44
import CUDAapi
55

66
using ..CuArrays
7-
using ..CuArrays: libcusolver, active_context, _getindex, unsafe_free!
7+
using ..CuArrays: libcusolver, active_context, unsafe_free!
8+
using GPUArrays: _getindex
89

910
using LinearAlgebra
1011
using SparseArrays

0 commit comments

Comments
 (0)