Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 18be296

Browse files
committed
Use unified memory for array allocations.
1 parent a794963 commit 18be296

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

src/indexing.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,31 @@
11
import GPUArrays: allowscalar, @allowscalar
22

3-
function _getindex(xs::CuArray{T}, i::Integer) where T
4-
buf = Array{T}(undef)
5-
copyto!(buf, 1, xs, i, 1)
6-
buf[]
3+
4+
## unified memory indexing
5+
6+
# TODO: needs to think about coherency -- otherwise this might crash since it doesn't sync
7+
# also, this optim would be relevant for CuArray<->Array memcpy as well.
8+
9+
function GPUArrays._getindex(xs::CuArray{T}, i::Integer) where T
10+
buf = buffer(xs)
11+
if isa(buf, Mem.UnifiedBuffer)
12+
ptr = convert(Ptr{T}, buffer(xs))
13+
unsafe_load(ptr, i)
14+
else
15+
val = Array{T}(undef)
16+
copyto!(val, 1, xs, i, 1)
17+
val[]
18+
end
719
end
820

9-
function _setindex!(xs::CuArray{T}, v::T, i::Integer) where T
10-
copyto!(xs, i, T[v], 1, 1)
21+
function GPUArrays._setindex!(xs::CuArray{T}, v::T, i::Integer) where T
22+
buf = buffer(xs)
23+
if isa(buf, Mem.UnifiedBuffer)
24+
ptr = convert(Ptr{T}, buffer(xs))
25+
unsafe_store!(ptr, v, i)
26+
else
27+
copyto!(xs, i, T[v], 1, 1)
28+
end
1129
end
1230

1331

@@ -19,7 +37,7 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
1937
bools = reshape(bools, prod(size(bools)))
2038
indices = cumsum(bools) # unique indices for elements that are true
2139

22-
n = _getindex(indices, length(indices)) # number that are true
40+
n = GPUArrays._getindex(indices, length(indices)) # number that are true
2341
ys = CuArray{T}(undef, n)
2442

2543
if n > 0

src/memory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function actual_alloc(bytes)
5050
# try the actual allocation
5151
try
5252
alloc_stats.actual_time += Base.@elapsed begin
53-
@timeit alloc_to "alloc" buf = Mem.alloc(Mem.Device, bytes)
53+
@timeit alloc_to "alloc" buf = Mem.alloc(Mem.Unified, bytes)
5454
end
5555
@assert sizeof(buf) == bytes
5656
alloc_stats.actual_nalloc += 1

src/solver/CUSOLVER.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module CUSOLVER
22

33
using ..CuArrays
4-
using ..CuArrays: libcusolver, active_context, _getindex, unsafe_free!
4+
using ..CuArrays: libcusolver, active_context, unsafe_free!
5+
using GPUArrays: _getindex
56

67
using ..CUBLAS: cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasDiagType_t
78
using ..CUSPARSE: cusparseMatDescr_t

0 commit comments

Comments
 (0)