Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Use unified memory for array allocations. #336

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ version = "1.2.0"

[[CUDAdrv]]
deps = ["CUDAapi", "Libdl", "Printf"]
git-tree-sha1 = "9ce99b5732c70e06ed97c042187baed876fb1698"
git-tree-sha1 = "f4420a71d8847fa13ad70d744fe5c3696b7efca0"
repo-rev = "master"
repo-url = "https://github.com/JuliaGPU/CUDAdrv.jl.git"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "3.1.0"

Expand Down
8 changes: 6 additions & 2 deletions src/CuArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ function __init__()
# package integrations
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")

# update the active context when we switch devices
callback = (::CuDevice, ctx::CuContext) -> begin
callback = (dev::CuDevice, ctx::CuContext) -> begin
# update the active context
active_context[] = ctx

# wipe the active handles
Expand All @@ -103,6 +103,10 @@ function __init__()
CURAND._generator[] = nothing
CUDNN._handle[] = C_NULL
CUTENSOR._handle[] = C_NULL

# update the coherent memory access indicator
coherent[] = CUDAdrv.version() >= v"9.0" &&
attribute(dev, CUDAdrv.CONCURRENT_MANAGED_ACCESS)
end
push!(CUDAnative.device!_listeners, callback)

Expand Down
47 changes: 37 additions & 10 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@
import GPUArrays: allowscalar, @allowscalar

function _getindex(xs::CuArray{T}, i::Integer) where T
buf = Array{T}(undef)
copyto!(buf, 1, xs, i, 1)
buf[]

## unified memory indexing

# > Simultaneous access to managed memory from the CPU and GPUs of compute capability lower
# > than 6.0 is not possible. This is because pre-Pascal GPUs lack hardware page faulting,
# > so coherence can’t be guaranteed. On these GPUs, an access from the CPU while a kernel
# > is running will cause a segmentation fault.
#
# > On Pascal and later GPUs, the CPU and the GPU can simultaneously access managed memory,
# > since they can both handle page faults; however, it is up to the application developer
# > to ensure there are no race conditions caused by simultaneous accesses.
const coherent = Ref(false)

# Read the single element at linear index `i` of `xs` and return it as a host value.
#
# If the array's buffer is a `Mem.UnifiedBuffer`, the element is loaded directly through a
# host pointer to that buffer. Otherwise it is staged through a zero-dimensional host
# array with `copyto!`.
function GPUArrays._getindex(xs::CuArray{T}, i::Integer) where T
    buf = buffer(xs)
    if isa(buf, Mem.UnifiedBuffer)
        # `unsafe_load` does no bounds checking of its own, so reject out-of-range
        # indices before dereferencing the raw pointer.
        @boundscheck checkbounds(xs, i)
        # When `coherent[]` is false, synchronize before touching the managed buffer
        # from the host (see the note on simultaneous access next to `coherent`).
        coherent[] || CUDAdrv.synchronize()
        ptr = convert(Ptr{T}, buf)
        unsafe_load(ptr, i)
    else
        val = Array{T}(undef)
        copyto!(val, 1, xs, i, 1)
        val[]
    end
end

function _setindex!(xs::CuArray{T}, v::T, i::Integer) where T
copyto!(xs, i, T[v], 1, 1)
# Store the value `v` at linear index `i` of `xs`.
#
# If the array's buffer is a `Mem.UnifiedBuffer`, the value is written directly through a
# host pointer to that buffer. Otherwise a one-element host array holding `v` is copied
# into place with `copyto!`.
function GPUArrays._setindex!(xs::CuArray{T}, v::T, i::Integer) where T
    buf = buffer(xs)
    if isa(buf, Mem.UnifiedBuffer)
        # `unsafe_store!` does no bounds checking of its own; an out-of-range index
        # would write outside the array, so validate it first.
        @boundscheck checkbounds(xs, i)
        # When `coherent[]` is false, synchronize before touching the managed buffer
        # from the host (see the note on simultaneous access next to `coherent`).
        coherent[] || CUDAdrv.synchronize()
        ptr = convert(Ptr{T}, buf)
        unsafe_store!(ptr, v, i)
    else
        copyto!(xs, i, T[v], 1, 1)
    end
end


Expand All @@ -17,9 +44,9 @@ Base.getindex(xs::CuArray, bools::AbstractArray{Bool}) = getindex(xs, CuArray(bo

function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
bools = reshape(bools, prod(size(bools)))
indices = cumsum(bools) # unique indices for elements that are true
indices = @sync cumsum(bools) # unique indices for elements that are true

n = _getindex(indices, length(indices)) # number that are true
n = GPUArrays._getindex(indices, length(indices)) # number that are true
ys = CuArray{T}(undef, n)

if n > 0
Expand Down Expand Up @@ -55,9 +82,9 @@ end
## findall

function Base.findall(bools::CuArray{Bool})
indices = cumsum(bools)
indices = @sync cumsum(bools)

n = _getindex(indices, length(indices))
n = GPUArrays._getindex(indices, length(indices))
ys = CuArray{Int}(undef, n)

if n > 0
Expand Down
2 changes: 1 addition & 1 deletion src/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function actual_alloc(bytes)
# try the actual allocation
try
alloc_stats.actual_time += Base.@elapsed begin
@timeit alloc_to "alloc" buf = Mem.alloc(Mem.Device, bytes)
@timeit alloc_to "alloc" buf = Mem.alloc(Mem.Unified, bytes)
end
@assert sizeof(buf) == bytes
alloc_stats.actual_nalloc += 1
Expand Down
3 changes: 2 additions & 1 deletion src/solver/CUSOLVER.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module CUSOLVER

using ..CuArrays
using ..CuArrays: libcusolver, active_context, _getindex, unsafe_free!
using ..CuArrays: libcusolver, active_context, unsafe_free!, @sync
using GPUArrays: _getindex

using ..CUBLAS: cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasDiagType_t
using ..CUSPARSE: cusparseMatDescr_t
Expand Down
40 changes: 20 additions & 20 deletions src/solver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cuuplo, n, A, lda, buffer, bufSize[], devinfo)
@sync $fname(dense_handle(), cuuplo, n, A, lda, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = BlasInt(_getindex(devinfo, 1))
Expand Down Expand Up @@ -65,7 +65,7 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
ldb = max(1, stride(B, 2))

devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cuuplo, n, nrhs, A, lda, B, ldb, devinfo)
@sync $fname(dense_handle(), cuuplo, n, nrhs, A, lda, B, ldb, devinfo)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
Expand All @@ -91,7 +91,7 @@ for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :F
buffer = CuArray{$elty}(undef, bufSize[])
devipiv = CuArray{Cint}(undef, min(m,n))
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), m, n, A, lda, buffer, devipiv, devinfo)
@sync $fname(dense_handle(), m, n, A, lda, buffer, devipiv, devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -122,7 +122,7 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
buffer = CuArray{$elty}(undef, bufSize[])
tau = CuArray{$elty}(undef, min(m, n))
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), m, n, A, lda, tau, buffer, bufSize[], devinfo)
@sync $fname(dense_handle(), m, n, A, lda, tau, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -153,7 +153,7 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
buffer = CuArray{$elty}(undef, bufSize[])
devipiv = CuArray{Cint}(undef, n)
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cuuplo, n, A, lda, devipiv, buffer, bufSize[], devinfo)
@sync $fname(dense_handle(), cuuplo, n, A, lda, devipiv, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -192,7 +192,7 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
ldb = max(1, stride(B, 2))

devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cutrans, n, nrhs, A, lda, ipiv, B, ldb, devinfo)
@sync $fname(dense_handle(), cutrans, n, nrhs, A, lda, ipiv, B, ldb, devinfo)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
Expand Down Expand Up @@ -240,8 +240,8 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cuside, cutrans, m, n, k, A, lda, tau, C, ldc, buffer,
bufSize[], devinfo)
@sync $fname(dense_handle(), cuside, cutrans, m, n, k, A, lda, tau, C, ldc, buffer,
bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -272,7 +272,7 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, :

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), m, n, k, A, lda, tau, buffer, bufSize[], devinfo)
@sync $fname(dense_handle(), m, n, k, A, lda, tau, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -310,7 +310,7 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
E = CuArrays.zeros($relty, k)
TAUQ = CuArray{$elty}(undef, k)
TAUP = CuArray{$elty}(undef, k)
$fname(dense_handle(), m, n, A, lda, D, E, TAUQ, TAUP, buffer, bufSize[], devinfo)
@sync $fname(dense_handle(), m, n, A, lda, D, E, TAUQ, TAUP, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -364,8 +364,8 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
work = CuArray{$elty}(undef, lwork[])
rwork = CuArray{$relty}(undef, min(m, n) - 1)
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt,
work, lwork[], rwork, devinfo)
@sync $fname(dense_handle(), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt,
work, lwork[], rwork, devinfo)
unsafe_free!(work)
unsafe_free!(rwork)

Expand Down Expand Up @@ -423,8 +423,8 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS

work = CuArray{$elty}(undef, lwork[])
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cujobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
work, lwork[], devinfo, params[])
@sync $fname(dense_handle(), cujobz, econ, m, n, A, lda, S, U, ldu, V, ldv,
work, lwork[], devinfo, params[])
unsafe_free!(work)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -459,8 +459,8 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cujobz, cuuplo, n, A, lda, W,
buffer, bufSize[], devinfo)
@sync $fname(dense_handle(), cujobz, cuuplo, n, A, lda, W,
buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -505,8 +505,8 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cuitype, cujobz, cuuplo, n, A, lda, B, ldb, W,
buffer, bufSize[], devinfo)
@sync $fname(dense_handle(), cuitype, cujobz, cuuplo, n, A, lda, B, ldb, W,
buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down Expand Up @@ -558,8 +558,8 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
$fname(dense_handle(), cuitype, cujobz, cuuplo, n, A, lda, B, ldb, W,
buffer, bufSize[], devinfo, params[])
@sync $fname(dense_handle(), cuitype, cujobz, cuuplo, n, A, lda, B, ldb, W,
buffer, bufSize[], devinfo, params[])
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
Expand Down
2 changes: 1 addition & 1 deletion src/solver/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) wher
# Return entry (i, j) of the packed Q factor `A`.
#
# Builds a vector of length `size(A, 2)` that is zero everywhere except for a one at
# position `j`, multiplies it in place by `A` with `lmul!`, and reads element `i` of the
# result back to the host.
function Base.getindex(A::CuQRPackedQ{T, S}, i::Integer, j::Integer) where {T, S}
    x = CuArrays.zeros(T, size(A, 2))
    x[j] = 1
    # Apply `A` exactly once. NOTE(review): `@sync` presumably blocks until the device
    # work has finished so the host-side read below sees the result — confirm against
    # the CuArrays definition.
    @sync lmul!(A, x)
    return _getindex(x, i)
end

Expand Down