Skip to content

Commit 94461cf

Browse files
committed
WIP: Refactor memory hierarchy.
1 parent abc5009 commit 94461cf

File tree

14 files changed

+149
-165
lines changed

14 files changed

+149
-165
lines changed

LocalPreferences.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[OpenCL]
22
# Which memory back-end to use for unspecified CLArray allocations. This can be:
33
# - "usm": Unified Shared Memory (`cl_intel_unified_shared_memory`)
4-
# - "bda": Buffer Device Address (`cl_mem` + `cl_ext_buffer_device_address`)
4+
# - "bda": plain buffers (`cl_mem` + `cl_ext_buffer_device_address`)
55
# - "svm": Shared Virtual Memory (coarse-grained)
66
# If unspecified, the default will be used based on the platform and device capabilities.
77
#default_memory_backend="..."

lib/cl/CL.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ include("device.jl")
2020
include("context.jl")
2121
include("cmdqueue.jl")
2222
include("event.jl")
23-
include("buffer.jl")
24-
include("memory/memory.jl")
23+
include("memory.jl")
2524
include("program.jl")
2625
include("kernel.jl")
2726

lib/cl/kernel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function set_arg!(k::Kernel, idx::Integer, arg::AbstractMemory)
7979
clSetKernelArgSVMPointer(k, idx - 1, pointer(arg))
8080
elseif arg isa UnifiedMemory
8181
clSetKernelArgMemPointerINTEL(k, idx - 1, pointer(arg))
82-
elseif arg isa BufferDeviceMemory
82+
elseif arg isa Buffer
8383
clSetKernelArgDevicePointerEXT(k, idx - 1, pointer(arg))
8484
else
8585
error("Unknown memory type")
@@ -203,7 +203,7 @@ function call(
203203

204204
if memory isa SharedVirtualMemory
205205
push!(svm_pointers, ptr)
206-
elseif memory isa BufferDeviceMemory
206+
elseif memory isa Buffer
207207
push!(bda_pointers, ptr)
208208
elseif memory isa UnifiedDeviceMemory
209209
device_access = true

lib/cl/memory.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Raw memory management
2+
3+
abstract type AbstractMemoryObject <: CLObject end
4+
abstract type AbstractPointerMemory end
5+
const AbstractMemory = Union{AbstractMemoryObject, AbstractPointerMemory}
6+
7+
# this will be specialized for each memory type
8+
Base.convert(T::Type{<:Union{Ptr, CLPtr}}, mem::AbstractMemory) =
9+
throw(ArgumentError("Illegal conversion of a $(typeof(mem)) to a $T"))
10+
11+
# ccall integration
12+
#
13+
# taking the pointer of a memory object means returning the underlying pointer,
14+
# and not the pointer of the object itself.
15+
Base.unsafe_convert(P::Type{<:Union{Ptr, CLPtr}}, mem::AbstractMemory) = convert(P, mem)
16+
17+
18+
## opaque memory objects
19+
20+
# This should be implemented by all subtypes
21+
#type MemoryType <: AbstractMemoryObject
22+
# id::cl_mem
23+
# ...
24+
#end
25+
26+
Base.sizeof(mem::AbstractMemoryObject) = mem.size
27+
28+
release(mem::AbstractMemoryObject) = clReleaseMemObject(mem)
29+
30+
function Base.getproperty(mem::AbstractMemoryObject, s::Symbol)
31+
if s == :type
32+
result = Ref{cl_mem_object_type}()
33+
clGetMemObjectInfo(mem, CL_MEM_TYPE, sizeof(cl_mem_object_type), result, C_NULL)
34+
return result[]
35+
elseif s == :flags
36+
result = Ref{cl_mem_flags}()
37+
clGetMemObjectInfo(mem, CL_MEM_FLAGS, sizeof(cl_mem_flags), result, C_NULL)
38+
mf = result[]
39+
flags = Symbol[]
40+
if (mf & CL_MEM_READ_WRITE) != 0
41+
push!(flags, :rw)
42+
end
43+
if (mf & CL_MEM_WRITE_ONLY) != 0
44+
push!(flags, :w)
45+
end
46+
if (mf & CL_MEM_READ_ONLY) != 0
47+
push!(flags, :r)
48+
end
49+
if (mf & CL_MEM_USE_HOST_PTR) != 0
50+
push!(flags, :use)
51+
end
52+
if (mf & CL_MEM_ALLOC_HOST_PTR) != 0
53+
push!(flags, :alloc)
54+
end
55+
if (mf & CL_MEM_COPY_HOST_PTR) != 0
56+
push!(flags, :copy)
57+
end
58+
return tuple(flags...)
59+
elseif s == :size
60+
result = Ref{Csize_t}()
61+
clGetMemObjectInfo(mem, CL_MEM_SIZE, sizeof(Csize_t), result, C_NULL)
62+
return result[]
63+
elseif s == :reference_count
64+
result = Ref{Cuint}()
65+
clGetMemObjectInfo(mem, CL_MEM_REFERENCE_COUNT, sizeof(Cuint), result, C_NULL)
66+
return Int(result[])
67+
elseif s == :map_count
68+
result = Ref{Cuint}()
69+
clGetMemObjectInfo(mem, CL_MEM_MAP_COUNT, sizeof(Cuint), result, C_NULL)
70+
return Int(result[])
71+
elseif s == :device_address
72+
result = Ref{cl_mem_device_address_ext}()
73+
clGetMemObjectInfo(mem, CL_MEM_DEVICE_ADDRESS_EXT, sizeof(cl_mem_device_address_ext), result, C_NULL)
74+
return CLPtr{Cvoid}(result[])
75+
else
76+
return getfield(mem, s)
77+
end
78+
end
79+
80+
# for passing buffers to OpenCL APIs: use the underlying handle
81+
Base.unsafe_convert(::Type{cl_mem}, mem::AbstractMemoryObject) = mem.id
82+
83+
# for passing buffers to kernels: pass the private device pointer
84+
Base.convert(::Type{CLPtr{T}}, mem::AbstractMemoryObject) where {T} =
85+
convert(CLPtr{T}, pointer(mem))
86+
# XXX: for passing buffers directly, we can support non-BDA drivers
87+
# by postponing the conversion to `cl.set_arg!`
88+
#Base.unsafe_convert(::Type{<:Ptr}, mem::AbstractMemoryObject) = mem
89+
90+
include("memory/buffer.jl")
91+
92+
#TODO: enqueue_migrate_mem_objects(queue, mem_objects, flags=0, wait_for=None)
93+
#TODO: enqueue_migrate_mem_objects_ext(queue, mem_objects, flags=0, wait_for=None)
94+
95+
96+
## pointer-based memory
97+
98+
include("memory/usm.jl")
99+
include("memory/svm.jl")

lib/cl/memory/bda.jl

Lines changed: 0 additions & 33 deletions
This file was deleted.

lib/cl/buffer.jl renamed to lib/cl/memory/buffer.jl

Lines changed: 16 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,17 @@
1-
# OpenCL Memory Object
2-
3-
abstract type AbstractMemoryObject <: CLObject end
4-
5-
#This should be implemented by all subtypes
6-
# type MemoryType <: AbstractMemoryObject
7-
# id::cl_mem
8-
# ...
9-
# end
10-
11-
# for passing buffers to OpenCL APIs: use the underlying handle
12-
Base.unsafe_convert(::Type{cl_mem}, mem::AbstractMemoryObject) = mem.id
13-
14-
# for passing buffers to kernels: keep the buffer, it's handled by `cl.set_arg!`
15-
Base.unsafe_convert(::Type{<:Ptr}, mem::AbstractMemoryObject) = mem
16-
17-
Base.sizeof(mem::AbstractMemoryObject) = mem.size
18-
19-
release(mem::AbstractMemoryObject) = clReleaseMemObject(mem)
20-
21-
function Base.getproperty(mem::AbstractMemoryObject, s::Symbol)
22-
if s == :context
23-
param = Ref{cl_context}()
24-
clGetMemObjectInfo(mem, CL_MEM_CONTEXT, sizeof(cl_context), param, C_NULL)
25-
return Context(param[], retain = true)
26-
elseif s == :mem_type
27-
result = Ref{cl_mem_object_type}()
28-
clGetMemObjectInfo(mem, CL_MEM_TYPE, sizeof(cl_mem_object_type), result, C_NULL)
29-
return result[]
30-
elseif s == :mem_flags
31-
result = Ref{cl_mem_flags}()
32-
clGetMemObjectInfo(mem, CL_MEM_FLAGS, sizeof(cl_mem_flags), result, C_NULL)
33-
mf = result[]
34-
flags = Symbol[]
35-
if (mf & CL_MEM_READ_WRITE) != 0
36-
push!(flags, :rw)
37-
end
38-
if (mf & CL_MEM_WRITE_ONLY) != 0
39-
push!(flags, :w)
40-
end
41-
if (mf & CL_MEM_READ_ONLY) != 0
42-
push!(flags, :r)
43-
end
44-
if (mf & CL_MEM_USE_HOST_PTR) != 0
45-
push!(flags, :use)
46-
end
47-
if (mf & CL_MEM_ALLOC_HOST_PTR) != 0
48-
push!(flags, :alloc)
49-
end
50-
if (mf & CL_MEM_COPY_HOST_PTR) != 0
51-
push!(flags, :copy)
52-
end
53-
return tuple(flags...)
54-
elseif s == :size
55-
result = Ref{Csize_t}()
56-
clGetMemObjectInfo(mem, CL_MEM_SIZE, sizeof(Csize_t), result, C_NULL)
57-
return result[]
58-
elseif s == :reference_count
59-
result = Ref{Cuint}()
60-
clGetMemObjectInfo(mem, CL_MEM_REFERENCE_COUNT, sizeof(Cuint), result, C_NULL)
61-
return Int(result[])
62-
elseif s == :map_count
63-
result = Ref{Cuint}()
64-
clGetMemObjectInfo(mem, CL_MEM_MAP_COUNT, sizeof(Cuint), result, C_NULL)
65-
return Int(result[])
66-
elseif s == :device_address
67-
result = Ref{cl_mem_device_address_ext}()
68-
clGetMemObjectInfo(mem, CL_MEM_DEVICE_ADDRESS_EXT, sizeof(cl_mem_device_address_ext), result, C_NULL)
69-
return CLPtr{Cvoid}(result[])
70-
else
71-
return getfield(mem, s)
72-
end
73-
end
74-
75-
# convenience functions
76-
context(mem::AbstractMemoryObject) = mem.context
77-
Base.pointer(mem::AbstractMemoryObject) = mem.pointer
78-
79-
#TODO: enqueue_migrate_mem_objects(queue, mem_objects, flags=0, wait_for=None)
80-
#TODO: enqueue_migrate_mem_objects_ext(queue, mem_objects, flags=0, wait_for=None)
81-
821
# OpenCL.Buffer
832

843
struct Buffer <: AbstractMemoryObject
854
id::cl_mem
5+
ptr::Union{Nothing,CLPtr{Cvoid}}
866
bytesize::Int
7+
context::Context
878
end
889

10+
Buffer() = Buffer(C_NULL, nothing, 0, context())
11+
12+
Base.pointer(buf::Buffer) = @something buf.ptr error("Buffer does not have a device private address")
8913
Base.sizeof(buf::Buffer) = buf.bytesize
14+
context(buf::Buffer) = buf.context
9015

9116

9217
## constructors
@@ -130,7 +55,16 @@ function Buffer(sz::Int, flags::Integer, hostbuf=nothing;
13055
if err_code[] != CL_SUCCESS
13156
throw(CLError(err_code[]))
13257
end
133-
return Buffer(mem_id, sz)
58+
59+
ptr = if device_private_address
60+
ptr_ref = Ref{cl_mem_device_address_ext}()
61+
clGetMemObjectInfo(mem_id, CL_MEM_DEVICE_ADDRESS_EXT, sizeof(cl_mem_device_address_ext), ptr_ref, C_NULL)
62+
CLPtr{Cvoid}(ptr_ref[])
63+
else
64+
nothing
65+
end
66+
67+
return Buffer(mem_id, ptr, sz, context())
13468
end
13569

13670
# allocated buffer

lib/cl/memory/memory.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.

lib/cl/memory/svm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct SharedVirtualMemory <: AbstractMemory
1+
struct SharedVirtualMemory <: AbstractPointerMemory
22
ptr::CLPtr{Cvoid}
33
bytesize::Int
44
context::Context

lib/cl/memory/usm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
abstract type UnifiedMemory <: AbstractMemory end
1+
abstract type UnifiedMemory <: AbstractPointerMemory end
22

33
function usm_free(mem::UnifiedMemory; blocking::Bool = false)
44
if blocking

lib/cl/state.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ function default_memory_backend(dev::Device)
199199

200200
backend = if backend_str == "usm"
201201
USMBackend()
202-
elseif backend_str == "bda"
203-
BDABackend()
204202
elseif backend_str == "svm"
205203
SVMBackend()
204+
elseif backend_str == "bda"
205+
BDABackend()
206206
else
207207
error("Unknown memory backend '$backend_str' requested")
208208
end

0 commit comments

Comments
 (0)