Skip to content

Commit 126e49b

Browse files
VarLadmaleadt
andauthored
Revamp memory management. (#264)
Adds support for USM next to SVM, while also porting several features from CUDA.jl (CLPtr, Managed, etc). Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 785565c commit 126e49b

36 files changed

+2736
-1198
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2121
Adapt = "4"
2222
GPUArrays = "11.2.1"
2323
GPUCompiler = "0.27, 1"
24-
KernelAbstractions = "0.9.1"
24+
KernelAbstractions = "0.9.2"
2525
LLVM = "9.1"
2626
LinearAlgebra = "1"
2727
OpenCL_jll = "=2024.5.8"

examples/vadd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ prog = cl.Program(; source) |> cl.build!
2121
kern = cl.Kernel(prog, "vadd")
2222

2323
len = prod(dims)
24-
clcall(kern, Tuple{Ptr{Float32}, Ptr{Float32}, Ptr{Float32}},
24+
clcall(kern, Tuple{CLPtr{Float32}, CLPtr{Float32}, CLPtr{Float32}},
2525
d_a, d_b, d_c; global_size=(len,))
2626
c = Array(d_c)
2727
@test a+b c

lib/cl/CL.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module cl
22

3+
using Printf
4+
5+
include("pointer.jl")
36
include("api.jl")
47

58
# OpenCL wrapper objects are expected to have an `id` field containing a handle pointer
@@ -15,9 +18,8 @@ include("device.jl")
1518
include("context.jl")
1619
include("cmdqueue.jl")
1720
include("event.jl")
18-
include("memory.jl")
21+
include("memory/memory.jl")
1922
include("buffer.jl")
20-
include("svm.jl")
2123
include("program.jl")
2224
include("kernel.jl")
2325

lib/cl/api.jl

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ function retry_reclaim(f, isfailed)
6969
ret
7070
end
7171

72+
macro ext_ccall(ex)
73+
# decode the expression
74+
@assert Meta.isexpr(ex, :(::))
75+
call, ret = ex.args
76+
@assert Meta.isexpr(call, :call)
77+
target, argexprs... = call.args
78+
@assert Meta.isexpr(target, :(.))
79+
_, fn = target.args
80+
81+
@gensym fptr
82+
esc(quote
83+
$fptr = $clGetExtensionFunctionAddressForPlatform(platform(), $fn)
84+
@ccall $(Expr(:($), fptr))($(argexprs...))::$ret
85+
end)
86+
end
87+
7288
include("libopencl.jl")
7389

7490
@static if Sys.iswindows()
@@ -176,22 +192,4 @@ function __init__()
176192
if !OpenCL_jll.is_available()
177193
@error "OpenCL_jll is not available for your platform, OpenCL.jl. will not work."
178194
end
179-
180-
# ensure that operations executed by the REPL back-end finish before returning,
181-
# because displaying values happens on a different task
182-
if isdefined(Base, :active_repl_backend) && !isnothing(Base.active_repl_backend)
183-
push!(Base.active_repl_backend.ast_transforms, synchronize_opencl_tasks)
184-
end
185-
end
186-
187-
function synchronize_opencl_tasks(ex)
188-
quote
189-
try
190-
$(ex)
191-
finally
192-
if haskey($task_local_storage(), :CLDevice)
193-
$device_synchronize()
194-
end
195-
end
196-
end
197195
end

lib/cl/buffer.jl

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,79 @@
1+
# OpenCL Memory Object
2+
3+
abstract type AbstractMemoryObject <: CLObject end
4+
5+
#This should be implemented by all subtypes
6+
# type MemoryType <: AbstractMemoryObject
7+
# id::cl_mem
8+
# ...
9+
# end
10+
11+
# for passing buffers to OpenCL APIs: use the underlying handle
12+
Base.unsafe_convert(::Type{cl_mem}, mem::AbstractMemoryObject) = mem.id
13+
14+
# for passing buffers to kernels: keep the buffer, it's handled by `cl.set_arg!`
15+
Base.unsafe_convert(::Type{<:Ptr}, mem::AbstractMemoryObject) = mem
16+
17+
Base.sizeof(mem::AbstractMemoryObject) = mem.size
18+
19+
context(mem::AbstractMemoryObject) = mem.context
20+
21+
function Base.getproperty(mem::AbstractMemoryObject, s::Symbol)
22+
if s == :context
23+
param = Ref{cl_context}()
24+
clGetMemObjectInfo(mem, CL_MEM_CONTEXT, sizeof(cl_context), param, C_NULL)
25+
return Context(param[], retain = true)
26+
elseif s == :mem_type
27+
result = Ref{cl_mem_object_type}()
28+
clGetMemObjectInfo(mem, CL_MEM_TYPE, sizeof(cl_mem_object_type), result, C_NULL)
29+
return result[]
30+
elseif s == :mem_flags
31+
result = Ref{cl_mem_flags}()
32+
clGetMemObjectInfo(mem, CL_MEM_FLAGS, sizeof(cl_mem_flags), result, C_NULL)
33+
mf = result[]
34+
flags = Symbol[]
35+
if (mf & CL_MEM_READ_WRITE) != 0
36+
push!(flags, :rw)
37+
end
38+
if (mf & CL_MEM_WRITE_ONLY) != 0
39+
push!(flags, :w)
40+
end
41+
if (mf & CL_MEM_READ_ONLY) != 0
42+
push!(flags, :r)
43+
end
44+
if (mf & CL_MEM_USE_HOST_PTR) != 0
45+
push!(flags, :use)
46+
end
47+
if (mf & CL_MEM_ALLOC_HOST_PTR) != 0
48+
push!(flags, :alloc)
49+
end
50+
if (mf & CL_MEM_COPY_HOST_PTR) != 0
51+
push!(flags, :copy)
52+
end
53+
return tuple(flags...)
54+
elseif s == :size
55+
result = Ref{Csize_t}()
56+
clGetMemObjectInfo(mem, CL_MEM_SIZE, sizeof(Csize_t), result, C_NULL)
57+
return result[]
58+
elseif s == :reference_count
59+
result = Ref{Cuint}()
60+
clGetMemObjectInfo(mem, CL_MEM_REFERENCE_COUNT, sizeof(Cuint), result, C_NULL)
61+
return Int(result[])
62+
elseif s == :map_count
63+
result = Ref{Cuint}()
64+
clGetMemObjectInfo(mem, CL_MEM_MAP_COUNT, sizeof(Cuint), result, C_NULL)
65+
return Int(result[])
66+
else
67+
return getfield(mem, s)
68+
end
69+
end
70+
71+
#TODO: enqueue_migrate_mem_objects(queue, mem_objects, flags=0, wait_for=None)
72+
#TODO: enqueue_migrate_mem_objects_ext(queue, mem_objects, flags=0, wait_for=None)
73+
174
# OpenCL.Buffer
275

3-
mutable struct Buffer{T} <: AbstractMemory
76+
mutable struct Buffer{T} <: AbstractMemoryObject
477
const id::cl_mem
578
const len::Int
679

lib/cl/cmdqueue.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
mutable struct CmdQueue <: CLObject
44
const id::cl_command_queue
5+
Base.@atomic valid::Bool
56

67
function CmdQueue(q_id::cl_command_queue; retain::Bool=false)
7-
q = new(q_id)
8+
q = new(q_id, true)
89
retain && clRetainCommandQueue(q)
910
finalizer(q) do _
10-
# this is to prevent `device_synchronize()` operating on freed queues.
11-
# XXX: why does the WeakKeyDict contain freed objects?
12-
delete!(cl.queues, q)
13-
clReleaseCommandQueue(q)
11+
Base.@atomic q.valid = false
12+
clReleaseCommandQueue(q)
1413
end
1514
return q
1615
end

lib/cl/device.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,59 @@ function exec_capabilities(d::Device)
190190
)
191191
end
192192

193+
function usm_supported(d::Device)
194+
"cl_intel_unified_shared_memory" in d.extensions || return false
195+
return true
196+
end
197+
198+
function usm_capabilities(d::Device)
199+
usm_supported(d) || throw(ArgumentError("Unified Shared Memory not supported on this device"))
200+
201+
function check_capability_bits(mask::cl_device_unified_shared_memory_capabilities_intel)
202+
(;
203+
access = mask & CL_UNIFIED_SHARED_MEMORY_ACCESS_INTEL != 0,
204+
atomic_access = mask & CL_UNIFIED_SHARED_MEMORY_ATOMIC_ACCESS_INTEL != 0,
205+
concurrent_access = mask & CL_UNIFIED_SHARED_MEMORY_CONCURRENT_ACCESS_INTEL != 0,
206+
concurrent_atomic_access = mask & CL_UNIFIED_SHARED_MEMORY_CONCURRENT_ATOMIC_ACCESS_INTEL != 0,
207+
)
208+
end
209+
210+
host = Ref{cl_device_unified_shared_memory_capabilities_intel}()
211+
device = Ref{cl_device_unified_shared_memory_capabilities_intel}()
212+
single_device = Ref{cl_device_unified_shared_memory_capabilities_intel}()
213+
shared = Ref{cl_device_unified_shared_memory_capabilities_intel}()
214+
cross_device = Ref{cl_device_unified_shared_memory_capabilities_intel}()
215+
216+
clGetDeviceInfo(
217+
d, CL_DEVICE_HOST_MEM_CAPABILITIES_INTEL,
218+
sizeof(cl_device_unified_shared_memory_capabilities_intel), host, C_NULL
219+
)
220+
clGetDeviceInfo(
221+
d, CL_DEVICE_DEVICE_MEM_CAPABILITIES_INTEL,
222+
sizeof(cl_device_unified_shared_memory_capabilities_intel), device, C_NULL
223+
)
224+
clGetDeviceInfo(
225+
d, CL_DEVICE_SINGLE_DEVICE_SHARED_MEM_CAPABILITIES_INTEL,
226+
sizeof(cl_device_unified_shared_memory_capabilities_intel), single_device, C_NULL
227+
)
228+
clGetDeviceInfo(
229+
d, CL_DEVICE_SHARED_SYSTEM_MEM_CAPABILITIES_INTEL,
230+
sizeof(cl_device_unified_shared_memory_capabilities_intel), shared, C_NULL
231+
)
232+
clGetDeviceInfo(
233+
d, CL_DEVICE_CROSS_DEVICE_SHARED_MEM_CAPABILITIES_INTEL,
234+
sizeof(cl_device_unified_shared_memory_capabilities_intel), cross_device, C_NULL
235+
)
236+
237+
return (;
238+
host = check_capability_bits(host[]),
239+
device = check_capability_bits(device[]),
240+
single_device = check_capability_bits(single_device[]),
241+
shared = check_capability_bits(shared[]),
242+
cross_device = check_capability_bits(cross_device[]),
243+
)
244+
end
245+
193246
function svm_capabilities(d::Device)
194247
result = Ref{cl_device_svm_capabilities}()
195248
clGetDeviceInfo(d, CL_DEVICE_SVM_CAPABILITIES,

0 commit comments

Comments
 (0)