Skip to content

Commit 752a953

Browse files
authored
Port storage handling + wrapper materialization from CUDA.jl. (#468)
1 parent 15969e9 commit 752a953

File tree

6 files changed

+577
-44
lines changed

6 files changed

+577
-44
lines changed

lib/JLArrays/src/JLArrays.jl

Lines changed: 87 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
module JLArrays
77

8-
export JLArray, jl
8+
export JLArray, JLVector, JLMatrix, jl
99

1010
using GPUArrays
1111

@@ -86,18 +86,27 @@ end
8686
# array type
8787

8888
struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
89-
data::Array{T, N}
89+
data::Vector{UInt8}
90+
offset::Int
9091
dims::Dims{N}
91-
92-
function JLDeviceArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
93-
new(data, dims)
94-
end
9592
end
9693

94+
Base.elsize(::Type{<:JLDeviceArray{T}}) where {T} = sizeof(T)
95+
9796
Base.size(x::JLDeviceArray) = x.dims
97+
Base.sizeof(x::JLDeviceArray) = Base.elsize(x) * length(x)
98+
99+
Base.unsafe_convert(::Type{Ptr{T}}, x::JLDeviceArray{T}) where {T} =
100+
Base.unsafe_convert(Ptr{T}, x.data) + x.offset*Base.elsize(x)
101+
102+
# conversion of untyped data to a typed Array
103+
function typed_data(x::JLDeviceArray{T}) where {T}
104+
unsafe_wrap(Array, pointer(x), x.dims)
105+
end
106+
107+
@inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
108+
@inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)
98109

99-
@inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(A.data, index)
100-
@inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(A.data, x, index)
101110

102111
# indexing
103112

@@ -139,23 +148,60 @@ end
139148
# Host abstractions
140149
#
141150

142-
struct JLArray{T, N} <: AbstractGPUArray{T, N}
143-
data::Array{T, N}
151+
function check_eltype(T)
152+
if !Base.allocatedinline(T)
153+
explanation = explain_allocatedinline(T)
154+
error("""
155+
JLArray only supports element types that are allocated inline.
156+
$explanation""")
157+
end
158+
end
159+
160+
mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
161+
data::DataRef{Vector{UInt8}}
162+
163+
offset::Int # offset of the data in the buffer, in number of elements
164+
144165
dims::Dims{N}
145166

146-
function JLArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
147-
isbitstype(T) || error("JLArray only supports bits types")
148-
# when supporting isbits-union types, use `Base.allocatedinline` here.
149-
new(data, dims)
167+
# allocating constructor
168+
function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
169+
check_eltype(T)
170+
maxsize = prod(dims) * sizeof(T)
171+
data = Vector{UInt8}(undef, maxsize)
172+
ref = DataRef(data)
173+
obj = new{T,N}(ref, 0, dims)
174+
finalizer(unsafe_free!, obj)
150175
end
176+
177+
# low-level constructor for wrapping existing data
178+
function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
179+
offset::Int=0) where {T,N}
180+
check_eltype(T)
181+
obj = new{T,N}(ref, offset, dims)
182+
finalizer(unsafe_free!, obj)
183+
end
184+
end
185+
186+
unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)
187+
188+
# conversion of untyped data to a typed Array
189+
function typed_data(x::JLArray{T}) where {T}
190+
unsafe_wrap(Array, pointer(x), x.dims)
191+
end
192+
193+
function GPUArrays.derive(::Type{T}, N::Int, a::JLArray, dims::Dims, offset::Int) where {T}
194+
ref = copy(a.data)
195+
offset = (a.offset * Base.elsize(a)) ÷ sizeof(T) + offset
196+
JLArray{T,N}(ref, dims; offset)
151197
end
152198

153199

154-
## constructors
200+
## convenience constructors
155201

156-
# type and dimensionality specified, accepting dims as tuples of Ints
157-
JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} =
158-
JLArray{T,N}(Array{T, N}(undef, dims), dims)
202+
const JLVector{T} = JLArray{T,1}
203+
const JLMatrix{T} = JLArray{T,2}
204+
const JLVecOrMat{T} = Union{JLVector{T},JLMatrix{T}}
159205

160206
# type and dimensionality specified, accepting dims as series of Ints
161207
JLArray{T,N}(::UndefInitializer, dims::Integer...) where {T,N} = JLArray{T,N}(undef, dims)
@@ -172,7 +218,10 @@ Base.similar(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(undef, size(a))
172218
Base.similar(a::JLArray{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
173219
Base.similar(a::JLArray, ::Type{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
174220

175-
Base.copy(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(copy(a.data), size(a))
221+
function Base.copy(a::JLArray{T,N}) where {T,N}
222+
b = similar(a)
223+
@inbounds copyto!(b, a)
224+
end
176225

177226

178227
## derived types
@@ -181,31 +230,26 @@ export DenseJLArray, DenseJLVector, DenseJLMatrix, DenseJLVecOrMat,
181230
StridedJLArray, StridedJLVector, StridedJLMatrix, StridedJLVecOrMat,
182231
AnyJLArray, AnyJLVector, AnyJLMatrix, AnyJLVecOrMat
183232

184-
ContiguousSubJLArray{T,N,A<:JLArray} = Base.FastContiguousSubArray{T,N,A}
185-
186233
# dense arrays: stored contiguously in memory
187-
DenseReinterpretJLArray{T,N,A<:Union{JLArray,ContiguousSubJLArray}} =
188-
Base.ReinterpretArray{T,N,S,A} where S
189-
DenseReshapedJLArray{T,N,A<:Union{JLArray,ContiguousSubJLArray,DenseReinterpretJLArray}} =
190-
Base.ReshapedArray{T,N,A}
191-
DenseSubJLArray{T,N,A<:Union{JLArray,DenseReshapedJLArray,DenseReinterpretJLArray}} =
192-
Base.FastContiguousSubArray{T,N,A}
193-
DenseJLArray{T,N} = Union{JLArray{T,N}, DenseSubJLArray{T,N}, DenseReshapedJLArray{T,N},
194-
DenseReinterpretJLArray{T,N}}
234+
DenseJLArray{T,N} = JLArray{T,N}
195235
DenseJLVector{T} = DenseJLArray{T,1}
196236
DenseJLMatrix{T} = DenseJLArray{T,2}
197237
DenseJLVecOrMat{T} = Union{DenseJLVector{T}, DenseJLMatrix{T}}
198238

199239
# strided arrays
200-
StridedSubJLArray{T,N,A<:Union{JLArray,DenseReshapedJLArray,DenseReinterpretJLArray},
201-
I<:Tuple{Vararg{Union{Base.RangeIndex, Base.ReshapedUnitRange,
202-
Base.AbstractCartesianIndex}}}} = SubArray{T,N,A,I}
203-
StridedJLArray{T,N} = Union{JLArray{T,N}, StridedSubJLArray{T,N}, DenseReshapedJLArray{T,N},
204-
DenseReinterpretJLArray{T,N}}
240+
StridedSubJLArray{T,N,I<:Tuple{Vararg{Union{Base.RangeIndex, Base.ReshapedUnitRange,
241+
Base.AbstractCartesianIndex}}}} =
242+
SubArray{T,N,<:JLArray,I}
243+
StridedJLArray{T,N} = Union{JLArray{T,N}, StridedSubJLArray{T,N}}
205244
StridedJLVector{T} = StridedJLArray{T,1}
206245
StridedJLMatrix{T} = StridedJLArray{T,2}
207246
StridedJLVecOrMat{T} = Union{StridedJLVector{T}, StridedJLMatrix{T}}
208247

248+
Base.pointer(x::StridedJLArray{T}) where {T} = Base.unsafe_convert(Ptr{T}, x)
249+
@inline function Base.pointer(x::StridedJLArray{T}, i::Integer) where T
250+
Base.unsafe_convert(Ptr{T}, x) + Base._memory_offset(x, i)
251+
end
252+
209253
# anything that's (secretly) backed by a JLArray
210254
AnyJLArray{T,N} = Union{JLArray{T,N}, WrappedArray{T,N,JLArray,JLArray{T,N}}}
211255
AnyJLVector{T} = AnyJLArray{T,1}
@@ -221,13 +265,16 @@ Base.size(x::JLArray) = x.dims
221265
Base.sizeof(x::JLArray) = Base.elsize(x) * length(x)
222266

223267
Base.unsafe_convert(::Type{Ptr{T}}, x::JLArray{T}) where {T} =
224-
Base.unsafe_convert(Ptr{T}, x.data)
268+
Base.unsafe_convert(Ptr{T}, x.data[]) + x.offset*Base.elsize(x)
225269

226270

227271
## interop with Julia arrays
228272

229-
JLArray{T,N}(x::AbstractArray{<:Any,N}) where {T,N} =
230-
JLArray{T,N}(convert(Array{T}, x), size(x))
273+
function JLArray{T,N}(xs::AbstractArray{<:Any,N}) where {T,N}
274+
A = JLArray{T,N}(undef, size(xs))
275+
copyto!(A, convert(Array{T}, xs))
276+
return A
277+
end
231278

232279
# underspecified constructors
233280
JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
@@ -345,14 +392,15 @@ end
345392
GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
346393

347394
Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
348-
JLDeviceArray{T,N}(x.data, x.dims)
395+
JLDeviceArray{T,N}(x.data[], x.offset, x.dims)
349396

350397
function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
351398
init=nothing)
352399
if init !== nothing
353400
fill!(R, init)
354401
end
355-
@allowscalar Base.reducedim!(op, R.data, map(f, A))
402+
@allowscalar Base.reducedim!(op, typed_data(R), map(f, A))
403+
R
356404
end
357405

358406
end

src/host/abstractarray.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,98 @@
11
# core definition of the AbstractGPUArray type
22

33

4+
# storage handling
5+
6+
export DataRef
7+
8+
# DataRef provides a helper class to manage the storage of an array.
9+
#
10+
# There's multiple reasons we don't just put the data directly in a GPUArray struct:
11+
# - to share data between multiple arrays, e.g., to create views;
12+
# - to be able to early-free data and release GC pressure.
13+
#
14+
# To support this, wrap the data in a DataRef instead, and use it with the following methods:
15+
# - `ref[]`: get the data;
16+
# - `copy(ref)`: create a new reference, increasing the reference count;
17+
# - `unsafe_free!(ref)`: decrease the reference count, and free the data if it reaches 0.
18+
#
19+
# The contained RefCounted struct should not be used directly.
20+
21+
# shared, reference-counted state.
22+
mutable struct RefCounted{D}
23+
obj::D
24+
finalizer
25+
count::Threads.Atomic{Int}
26+
end
27+
28+
function retain(rc::RefCounted)
29+
if rc.count[] == 0
30+
throw(ArgumentError("Attempt to retain freed data."))
31+
end
32+
Threads.atomic_add!(rc.count, 1)
33+
return
34+
end
35+
36+
function release(rc::RefCounted, args...)
37+
if rc.count[] == 0
38+
throw(ArgumentError("Attempt to release freed data."))
39+
end
40+
refcount = Threads.atomic_add!(rc.count, -1)
41+
if refcount == 1 && rc.finalizer !== nothing
42+
rc.finalizer(rc.obj, args...)
43+
end
44+
return
45+
end
46+
47+
function Base.getindex(rc::RefCounted)
48+
if rc.count[] == 0
49+
throw(ArgumentError("Attempt to use freed data."))
50+
end
51+
rc.obj
52+
end
53+
54+
# per-object state, with a flag to indicate whether the object has been freed.
55+
# this is to support multiple calls to `unsafe_free!` on the same object,
56+
# while only lowering the reference count of the underlying data once.
57+
mutable struct DataRef{D}
58+
rc::RefCounted{D}
59+
freed::Bool
60+
end
61+
62+
function DataRef(finalizer, data::D) where {D}
63+
rc = RefCounted{D}(data, finalizer, Threads.Atomic{Int}(1))
64+
DataRef{D}(rc, false)
65+
end
66+
DataRef(data; kwargs...) = DataRef(nothing, data; kwargs...)
67+
68+
function Base.getindex(ref::DataRef)
69+
if ref.freed
70+
throw(ArgumentError("Attempt to use a freed reference."))
71+
end
72+
ref.rc[]
73+
end
74+
75+
function Base.copy(ref::DataRef{D}) where {D}
76+
if ref.freed
77+
throw(ArgumentError("Attempt to copy a freed reference."))
78+
end
79+
retain(ref.rc)
80+
return DataRef{D}(ref.rc, false)
81+
end
82+
83+
function unsafe_free!(ref::DataRef, args...)
84+
if ref.freed
85+
# multiple frees *of the same object* are allowed.
86+
# we should only ever call `release` once per object, though,
87+
# as multiple releases of the underlying data is not allowed.
88+
return
89+
end
90+
release(ref.rc, args...)
91+
ref.freed = true
92+
return
93+
end
94+
95+
496
# input/output
597

698
## serialization

0 commit comments

Comments
 (0)