5
5
6
6
module JLArrays
7
7
8
- export JLArray, jl
8
+ export JLArray, JLVector, JLMatrix, jl
9
9
10
10
using GPUArrays
11
11
86
86
# array type
87
87
88
88
struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
89
- data:: Array{T, N}
89
+ data:: Vector{UInt8}
90
+ offset:: Int
90
91
dims:: Dims{N}
91
-
92
- function JLDeviceArray {T,N} (data:: Array{T, N} , dims:: Dims{N} ) where {T,N}
93
- new (data, dims)
94
- end
95
92
end
96
93
94
+ Base. elsize (:: Type{<:JLDeviceArray{T}} ) where {T} = sizeof (T)
95
+
97
96
Base. size (x:: JLDeviceArray ) = x. dims
97
+ Base. sizeof (x:: JLDeviceArray ) = Base. elsize (x) * length (x)
98
+
99
+ Base. unsafe_convert (:: Type{Ptr{T}} , x:: JLDeviceArray{T} ) where {T} =
100
+ Base. unsafe_convert (Ptr{T}, x. data) + x. offset* Base. elsize (x)
101
+
102
+ # conversion of untyped data to a typed Array
103
+ function typed_data (x:: JLDeviceArray{T} ) where {T}
104
+ unsafe_wrap (Array, pointer (x), x. dims)
105
+ end
106
+
107
+ @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (typed_data (A), index)
108
+ @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (typed_data (A), x, index)
98
109
99
- @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (A. data, index)
100
- @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (A. data, x, index)
101
110
102
111
# indexing
103
112
@@ -139,23 +148,60 @@ end
139
148
# Host abstractions
140
149
#
141
150
142
- struct JLArray{T, N} <: AbstractGPUArray{T, N}
143
- data:: Array{T, N}
151
+ function check_eltype (T)
152
+ if ! Base. allocatedinline (T)
153
+ explanation = explain_allocatedinline (T)
154
+ error ("""
155
+ JLArray only supports element types that are allocated inline.
156
+ $explanation """ )
157
+ end
158
+ end
159
+
160
+ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
161
+ data:: DataRef{Vector{UInt8}}
162
+
163
+ offset:: Int # offset of the data in the buffer, in number of elements
164
+
144
165
dims:: Dims{N}
145
166
146
- function JLArray {T,N} (data:: Array{T, N} , dims:: Dims{N} ) where {T,N}
147
- isbitstype (T) || error (" JLArray only supports bits types" )
148
- # when supporting isbits-union types, use `Base.allocatedinline` here.
149
- new (data, dims)
167
+ # allocating constructor
168
+ function JLArray {T,N} (:: UndefInitializer , dims:: Dims{N} ) where {T,N}
169
+ check_eltype (T)
170
+ maxsize = prod (dims) * sizeof (T)
171
+ data = Vector {UInt8} (undef, maxsize)
172
+ ref = DataRef (data)
173
+ obj = new {T,N} (ref, 0 , dims)
174
+ finalizer (unsafe_free!, obj)
150
175
end
176
+
177
+ # low-level constructor for wrapping existing data
178
+ function JLArray {T,N} (ref:: DataRef{Vector{UInt8}} , dims:: Dims{N} ;
179
+ offset:: Int = 0 ) where {T,N}
180
+ check_eltype (T)
181
+ obj = new {T,N} (ref, offset, dims)
182
+ finalizer (unsafe_free!, obj)
183
+ end
184
+ end
185
+
186
+ unsafe_free! (a:: JLArray ) = GPUArrays. unsafe_free! (a. data)
187
+
188
+ # conversion of untyped data to a typed Array
189
+ function typed_data (x:: JLArray{T} ) where {T}
190
+ unsafe_wrap (Array, pointer (x), x. dims)
191
+ end
192
+
193
+ function GPUArrays. derive (:: Type{T} , N:: Int , a:: JLArray , dims:: Dims , offset:: Int ) where {T}
194
+ ref = copy (a. data)
195
+ offset = (a. offset * Base. elsize (a)) ÷ sizeof (T) + offset
196
+ JLArray {T,N} (ref, dims; offset)
151
197
end
152
198
153
199
154
- # # constructors
200
+ # # convenience constructors
155
201
156
- # type and dimensionality specified, accepting dims as tuples of Ints
157
- JLArray {T,N} ( :: UndefInitializer , dims :: Dims{N} ) where {T,N} =
158
- JLArray {T,N} ( Array {T, N} (undef, dims), dims)
202
+ const JLVector{T} = JLArray{T, 1 }
203
+ const JLMatrix{T} = JLArray {T,2 }
204
+ const JLVecOrMat{T} = Union{JLVector{T},JLMatrix{T}}
159
205
160
206
# type and dimensionality specified, accepting dims as series of Ints
161
207
JLArray {T,N} (:: UndefInitializer , dims:: Integer... ) where {T,N} = JLArray {T,N} (undef, dims)
@@ -172,7 +218,10 @@ Base.similar(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(undef, size(a))
172
218
Base. similar (a:: JLArray{T} , dims:: Base.Dims{N} ) where {T,N} = JLArray {T,N} (undef, dims)
173
219
Base. similar (a:: JLArray , :: Type{T} , dims:: Base.Dims{N} ) where {T,N} = JLArray {T,N} (undef, dims)
174
220
175
- Base. copy (a:: JLArray{T,N} ) where {T,N} = JLArray {T,N} (copy (a. data), size (a))
221
+ function Base. copy (a:: JLArray{T,N} ) where {T,N}
222
+ b = similar (a)
223
+ @inbounds copyto! (b, a)
224
+ end
176
225
177
226
178
227
# # derived types
@@ -181,31 +230,26 @@ export DenseJLArray, DenseJLVector, DenseJLMatrix, DenseJLVecOrMat,
181
230
StridedJLArray, StridedJLVector, StridedJLMatrix, StridedJLVecOrMat,
182
231
AnyJLArray, AnyJLVector, AnyJLMatrix, AnyJLVecOrMat
183
232
184
- ContiguousSubJLArray{T,N,A<: JLArray } = Base. FastContiguousSubArray{T,N,A}
185
-
186
233
# dense arrays: stored contiguously in memory
187
- DenseReinterpretJLArray{T,N,A<: Union{JLArray,ContiguousSubJLArray} } =
188
- Base. ReinterpretArray{T,N,S,A} where S
189
- DenseReshapedJLArray{T,N,A<: Union{JLArray,ContiguousSubJLArray,DenseReinterpretJLArray} } =
190
- Base. ReshapedArray{T,N,A}
191
- DenseSubJLArray{T,N,A<: Union{JLArray,DenseReshapedJLArray,DenseReinterpretJLArray} } =
192
- Base. FastContiguousSubArray{T,N,A}
193
- DenseJLArray{T,N} = Union{JLArray{T,N}, DenseSubJLArray{T,N}, DenseReshapedJLArray{T,N},
194
- DenseReinterpretJLArray{T,N}}
234
+ DenseJLArray{T,N} = JLArray{T,N}
195
235
DenseJLVector{T} = DenseJLArray{T,1 }
196
236
DenseJLMatrix{T} = DenseJLArray{T,2 }
197
237
DenseJLVecOrMat{T} = Union{DenseJLVector{T}, DenseJLMatrix{T}}
198
238
199
239
# strided arrays
200
- StridedSubJLArray{T,N,A<: Union{JLArray,DenseReshapedJLArray,DenseReinterpretJLArray} ,
201
- I<: Tuple {Vararg{Union{Base. RangeIndex, Base. ReshapedUnitRange,
202
- Base. AbstractCartesianIndex}}}} = SubArray{T,N,A,I}
203
- StridedJLArray{T,N} = Union{JLArray{T,N}, StridedSubJLArray{T,N}, DenseReshapedJLArray{T,N},
204
- DenseReinterpretJLArray{T,N}}
240
+ StridedSubJLArray{T,N,I<: Tuple {Vararg{Union{Base. RangeIndex, Base. ReshapedUnitRange,
241
+ Base. AbstractCartesianIndex}}}} =
242
+ SubArray{T,N,<: JLArray ,I}
243
+ StridedJLArray{T,N} = Union{JLArray{T,N}, StridedSubJLArray{T,N}}
205
244
StridedJLVector{T} = StridedJLArray{T,1 }
206
245
StridedJLMatrix{T} = StridedJLArray{T,2 }
207
246
StridedJLVecOrMat{T} = Union{StridedJLVector{T}, StridedJLMatrix{T}}
208
247
248
+ Base. pointer (x:: StridedJLArray{T} ) where {T} = Base. unsafe_convert (Ptr{T}, x)
249
+ @inline function Base. pointer (x:: StridedJLArray{T} , i:: Integer ) where T
250
+ Base. unsafe_convert (Ptr{T}, x) + Base. _memory_offset (x, i)
251
+ end
252
+
209
253
# anything that's (secretly) backed by a JLArray
210
254
AnyJLArray{T,N} = Union{JLArray{T,N}, WrappedArray{T,N,JLArray,JLArray{T,N}}}
211
255
AnyJLVector{T} = AnyJLArray{T,1 }
@@ -221,13 +265,16 @@ Base.size(x::JLArray) = x.dims
221
265
Base. sizeof (x:: JLArray ) = Base. elsize (x) * length (x)
222
266
223
267
Base. unsafe_convert (:: Type{Ptr{T}} , x:: JLArray{T} ) where {T} =
224
- Base. unsafe_convert (Ptr{T}, x. data)
268
+ Base. unsafe_convert (Ptr{T}, x. data[]) + x . offset * Base . elsize (x )
225
269
226
270
227
271
# # interop with Julia arrays
228
272
229
- JLArray {T,N} (x:: AbstractArray{<:Any,N} ) where {T,N} =
230
- JLArray {T,N} (convert (Array{T}, x), size (x))
273
+ function JLArray {T,N} (xs:: AbstractArray{<:Any,N} ) where {T,N}
274
+ A = JLArray {T,N} (undef, size (xs))
275
+ copyto! (A, convert (Array{T}, xs))
276
+ return A
277
+ end
231
278
232
279
# underspecified constructors
233
280
JLArray {T} (xs:: AbstractArray{S,N} ) where {T,N,S} = JLArray {T,N} (xs)
@@ -345,14 +392,15 @@ end
345
392
GPUArrays. backend (:: Type{<:JLArray} ) = JLBackend ()
346
393
347
394
Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N} =
348
- JLDeviceArray {T,N} (x. data, x. dims)
395
+ JLDeviceArray {T,N} (x. data[], x . offset , x. dims)
349
396
350
397
function GPUArrays. mapreducedim! (f, op, R:: AnyJLArray , A:: Union{AbstractArray,Broadcast.Broadcasted} ;
351
398
init= nothing )
352
399
if init != = nothing
353
400
fill! (R, init)
354
401
end
355
- @allowscalar Base. reducedim! (op, R. data, map (f, A))
402
+ @allowscalar Base. reducedim! (op, typed_data (R), map (f, A))
403
+ R
356
404
end
357
405
358
406
end
0 commit comments