Skip to content

Commit 25dd395

Browse files
VarLadmaleadt
andauthored
Allow direct host-interactions with SVM-backed arrays (#336)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 67a0fa7 commit 25dd395

File tree

8 files changed

+63
-27
lines changed

8 files changed

+63
-27
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
test:
1616
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.memory_backend }} - PoCL ${{ matrix.pocl }}
1717
runs-on: ${{ matrix.os }}
18-
timeout-minutes: 180
18+
timeout-minutes: 60
1919
permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
2020
actions: write
2121
contents: read

lib/cl/memory/svm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ Base.show(io::IO, mem::SharedVirtualMemory) =
4545
@printf(io, "SharedVirtualMemory(%s at %p)", Base.format_bytes(sizeof(mem)), Int(pointer(mem)))
4646

4747
Base.convert(::Type{Ptr{T}}, mem::SharedVirtualMemory) where {T} =
48-
convert(Ptr{T}, pointer(mem))
48+
convert(Ptr{T}, reinterpret(Ptr{Cvoid}, pointer(mem)))
4949

5050
Base.convert(::Type{CLPtr{T}}, mem::SharedVirtualMemory) where {T} =
5151
reinterpret(CLPtr{T}, pointer(mem))

src/array.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
export CLArray, CLVector, CLMatrix, CLVecOrMat, is_device, is_shared, is_host
1+
export CLArray, CLVector, CLMatrix, CLVecOrMat,
2+
device_accessible, host_accessible
23

34

45
## array type
@@ -176,11 +177,9 @@ memtype(x::CLArray) = memtype(typeof(x))
176177
memtype(::Type{<:CLArray{<:Any, <:Any, M}}) where {M} = @isdefined(M) ? M : Any
177178

178179
# can we read this array from the device (i.e. derive a CLPtr)?
179-
is_device(a::CLArray) =
180+
device_accessible(a::CLArray) =
180181
memtype(a) in (cl.UnifiedDeviceMemory, cl.UnifiedSharedMemory, cl.SharedVirtualMemory, cl.Buffer)
181-
is_shared(a::CLArray) =
182-
memtype(a) in (cl.UnifiedSharedMemory, cl.SharedVirtualMemory)
183-
is_host(a::CLArray) =
182+
host_accessible(a::CLArray) =
184183
memtype(a) in (cl.UnifiedHostMemory, cl.UnifiedSharedMemory, cl.SharedVirtualMemory)
185184

186185

@@ -272,12 +271,12 @@ Base.convert(::Type{T}, x::T) where {T <: CLArray} = x
272271

273272
## indexing
274273

275-
function Base.getindex(x::CLArray{<:Any, <:Any, <:Union{cl.UnifiedHostMemory, cl.UnifiedSharedMemory}}, I::Int)
274+
function Base.getindex(x::CLArray{<:Any, <:Any, <:Union{cl.UnifiedHostMemory, cl.UnifiedSharedMemory, cl.SharedVirtualMemory}}, I::Int)
276275
@boundscheck checkbounds(x, I)
277276
return GC.@preserve x unsafe_load(host_pointer(x, I))
278277
end
279278

280-
function Base.setindex!(x::CLArray{<:Any, <:Any, <:Union{cl.UnifiedHostMemory, cl.UnifiedSharedMemory}}, v, I::Int)
279+
function Base.setindex!(x::CLArray{<:Any, <:Any, <:Union{cl.UnifiedHostMemory, cl.UnifiedSharedMemory, cl.SharedVirtualMemory}}, v, I::Int)
281280
@boundscheck checkbounds(x, I)
282281
return GC.@preserve x unsafe_store!(host_pointer(x, I), v)
283282
end
@@ -286,14 +285,14 @@ end
286285
## interop with libraries
287286

288287
function Base.unsafe_convert(::Type{Ptr{T}}, x::CLArray{T}) where {T}
289-
if !is_host(x)
288+
if !host_accessible(x)
290289
throw(ArgumentError("cannot take the CPU address of a $(typeof(x))"))
291290
end
292291
return convert(Ptr{T}, x.data[]) + x.offset * Base.elsize(x)
293292
end
294293

295294
function Base.unsafe_convert(::Type{CLPtr{T}}, x::CLArray{T}) where {T}
296-
if !is_device(x)
295+
if !device_accessible(x)
297296
throw(ArgumentError("cannot take the device address of a $(typeof(x))"))
298297
end
299298
return convert(CLPtr{T}, x.data[]) + x.offset * Base.elsize(x)
@@ -485,16 +484,14 @@ Base.unsafe_convert(::Type{CLPtr{T}}, A::PermutedDimsArray) where {T} =
485484
## unsafe_wrap
486485

487486
"""
488-
unsafe_wrap(Array, arr::CLArray{_,_,cl.UnifiedSharedMemory})
487+
unsafe_wrap(Array, arr::CLArray)
489488
490489
Wrap a Julia `Array` around the buffer that backs a `CLArray`. This is only possible if the
491-
GPU array is backed by a shared buffer, i.e. if it was created with `CLArray{T}(undef, ...)`.
490+
GPU array is backed by host memory, such as unified (host or shared) memory, or shared
491+
virtual memory.
492492
"""
493-
function Base.unsafe_wrap(::Type{Array}, arr::CLArray{T, N, cl.UnifiedSharedMemory}) where {T, N}
494-
# TODO: can we make this more convenient by increasing the buffer's refcount and using
495-
# a finalizer on the Array? does that work when taking views etc of the Array?
496-
ptr = reinterpret(Ptr{T}, pointer(arr))
497-
return unsafe_wrap(Array, ptr, size(arr))
493+
function Base.unsafe_wrap(::Type{Array}, arr::CLArray{T, N}) where {T, N}
494+
return unsafe_wrap(Array, host_pointer(arr), size(arr))
498495
end
499496

500497

src/memory.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@ mutable struct Managed{M}
1515
# whether there are outstanding operations that haven't been synchronized
1616
dirty::Bool
1717

18-
function Managed(mem::cl.AbstractMemory; queue = cl.queue(), dirty = true)
18+
# who is currently using the memory
19+
user::Symbol
20+
21+
function Managed(mem::cl.AbstractMemory; queue = cl.queue(), dirty = true, user = :device)
1922
# NOTE: memory starts as dirty, because stream-ordered allocations are only
2023
# guaranteed to be physically allocated at a synchronization event.
21-
return new{typeof(mem)}(mem, queue, dirty)
24+
# NOTE: memory also starts as device-owned, because we need to map it as soon as
25+
# the host accesses it.
26+
return new{typeof(mem)}(mem, queue, dirty, user)
2227
end
2328
end
2429

@@ -38,7 +43,7 @@ function maybe_synchronize(managed::Managed)
3843
return nothing
3944
end
4045

41-
function Base.convert(typ::Union{Type{<:CLPtr}, Type{cl.Buffer}}, managed::Managed)
46+
function Base.convert(typ::Union{Type{<:CLPtr}, Type{cl.Buffer}}, managed::Managed{M}) where {M}
4247
# let null pointers pass through as-is
4348
# XXX: does not work for buffers
4449
ptr = convert(typ, managed.mem)
@@ -52,6 +57,13 @@ function Base.convert(typ::Union{Type{<:CLPtr}, Type{cl.Buffer}}, managed::Manag
5257
managed.queue = cl.queue()
5358
end
5459

60+
# coarse-grained SVM needs to be unmapped when accessing it back from the device
61+
# TODO: support fine-grained SVM
62+
if M == cl.SharedVirtualMemory && managed.user == :host
63+
cl.enqueue_svm_unmap(pointer(managed.mem))
64+
managed.user = :device
65+
end
66+
5567
managed.dirty = true
5668
return ptr
5769
end
@@ -74,6 +86,14 @@ function Base.convert(typ::Type{<:Ptr}, managed::Managed{M}) where {M}
7486

7587
# make sure any work on the memory has finished.
7688
maybe_synchronize(managed)
89+
90+
# coarse-grained SVM needs to be mapped when initially accessing it from the host
91+
# TODO: support fine-grained SVM
92+
if M == cl.SharedVirtualMemory && managed.user != :host
93+
cl.enqueue_svm_map(pointer(managed.mem), sizeof(managed.mem), :rw; blocking = true)
94+
managed.user = :host
95+
end
96+
7797
return ptr
7898
end
7999

@@ -162,6 +182,10 @@ function free(managed::Managed)
162182
end
163183

164184
if mem isa cl.SharedVirtualMemory
185+
if managed.user == :host
186+
# if the coarse-grained SVM buffer is mapped on the host, unmap it first.
187+
cl.enqueue_svm_unmap(pointer(mem))
188+
end
165189
cl.svm_free(mem)
166190
elseif mem isa cl.UnifiedMemory
167191
cl.usm_free(mem)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
88
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
99
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1112
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1213
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/array.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import Adapt
99
@test Base.elsize(xs) == sizeof(Int)
1010
@test CLArray{Int, 2}(xs) === xs
1111

12-
@test_throws ArgumentError Base.unsafe_convert(Ptr{Int}, xs)
13-
@test_throws ArgumentError Base.unsafe_convert(Ptr{Float32}, xs)
12+
if !host_accessible(xs)
13+
@test_throws ArgumentError Base.unsafe_convert(Ptr{Int}, xs)
14+
@test_throws ArgumentError Base.unsafe_convert(Ptr{Float32}, xs)
15+
end
1416

1517
@test collect(OpenCL.zeros(Float32, 2, 2)) == zeros(Float32, 2, 2)
1618
@test collect(OpenCL.ones(Float32, 2, 2)) == ones(Float32, 2, 2)

test/pointer.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ end
3737

3838

3939
@testset "GPU or CPU integration" begin
40-
4140
a = [1]
4241
ccall(:clock, Nothing, (Ptr{Int},), a)
4342
@test_throws Exception ccall(:clock, Nothing, (CLPtr{Int},), a)
4443
ccall(:clock, Nothing, (PtrOrCLPtr{Int},), a)
4544

4645
b = CLArray{eltype(a), ndims(a)}(undef, size(a))
4746
ccall(:clock, Nothing, (CLPtr{Int},), b)
48-
@test_throws Exception ccall(:clock, Nothing, (Ptr{Int},), b)
47+
if !host_accessible(b)
48+
@test_throws Exception ccall(:clock, Nothing, (Ptr{Int},), b)
49+
end
4950
ccall(:clock, Nothing, (PtrOrCLPtr{Int},), b)
50-
5151
end
5252

5353

test/runtests.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Dates
33
import REPL
44
using Printf: @sprintf
55
using Base.Filesystem: path_separator
6+
using Preferences
67

78
# parse some command-line arguments
89
function extract_flag!(args, flag, default=nothing)
@@ -110,7 +111,18 @@ if !isempty(optlike_args)
110111
error("Unknown test options `$(join(optlike_args, " "))` (try `--help` for usage instructions)")
111112
end
112113
## the remaining args filter tests
113-
if !isempty(ARGS)
114+
if isempty(ARGS)
115+
# default to running all tests, except:
116+
filter!(tests) do test
117+
if load_preference(OpenCL, "default_memory_backend") == "svm" &&
118+
test == "gpuarrays/indexing scalar"
119+
# GPUArrays' scalar indexing tests assume that indexing is not supported
120+
return false
121+
end
122+
123+
return true
124+
end
125+
else
114126
filter!(tests) do test
115127
any(arg->startswith(test, arg), ARGS)
116128
end

0 commit comments

Comments
 (0)