# reference implementation on the CPU
-
- # note that most of the code in this file serves to define a functional array type,
- # the actual implementation of GPUArrays-interfaces is much more limited.
+ # This acts as a wrapper around KernelAbstractions's parallel CPU
+ # functionality. It is useful for testing GPUArrays (and other packages)
+ # when no GPU is present.
+ # This file follows conventions from AMDGPU.jl
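A minimal usage sketch of this CPU fallback (illustrative only, assuming just what the file exports; `jl` wraps a host `Array` as a `JLArray`):

```julia
using JLArrays

a = jl(rand(Float32, 4, 4))  # wrap a host Array as a JLArray
b = a .+ 1f0                 # broadcasting goes through the GPUArrays machinery
Array(b)                     # copy back to a plain Array
```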
module JLArrays

- export JLArray, JLVector, JLMatrix, jl
-
using GPUArrays
-
using Adapt
+ import KernelAbstractions
+ import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config

+ export JLArray, JLVector, JLMatrix, jl, JLBackend

#
# Device functionality
#

const MAXTHREADS = 256

-
- ## execution
-
- struct JLBackend <: AbstractGPUBackend end
-
- mutable struct JLKernelContext <: AbstractKernelContext
-     blockdim::Int
-     griddim::Int
-     blockidx::Int
-     threadidx::Int
-
-     localmem_counter::Int
-     localmems::Vector{Vector{Array}}
- end
-
- function JLKernelContext(threads::Int, blockdim::Int)
-     blockcount = prod(blockdim)
-     lmems = [Vector{Array}() for i in 1:blockcount]
-     JLKernelContext(threads, blockdim, 1, 1, 0, lmems)
+ struct JLBackend <: KernelAbstractions.GPU
+     static::Bool
+     JLBackend(; static::Bool=false) = new(static)
end

- function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
-     JLKernelContext(
-         ctx.blockdim,
-         ctx.griddim,
-         ctx.blockidx,
-         threadidx,
-         0,
-         ctx.localmems
-     )
- end

struct Adaptor end
jlconvert(arg) = adapt(Adaptor(), arg)
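A rough sketch of how the new `JLBackend` is driven through the generic KernelAbstractions API (the `allocate` method it relies on is defined further down in this diff):

```julia
using JLArrays, KernelAbstractions

backend = JLBackend()  # dynamic task scheduling; JLBackend(static=true) requests static scheduling
a = KernelAbstractions.allocate(backend, Float32, 16, 16)  # a 16×16 JLArray{Float32}
a .= 0f0  # works like any other GPUArrays array type
```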
⋮
Base.getindex(r::JlRefValue) = r.x
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))

- function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
-                             name::Union{String,Nothing})
-     ctx = JLKernelContext(threads, blocks)
-     device_args = jlconvert.(args)
-     tasks = Array{Task}(undef, threads)
-     for blockidx in 1:blocks
-         ctx.blockidx = blockidx
-         for threadidx in 1:threads
-             thread_ctx = JLKernelContext(ctx, threadidx)
-             tasks[threadidx] = @async f(thread_ctx, device_args...)
-             # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
-             # (this would require a different synchronization mechanism)
-         end
-         for t in tasks
-             fetch(t)
-         end
+ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
+     data::DataRef{Vector{UInt8}}
+
+     offset::Int  # offset of the data in the buffer, in number of elements
+
+     dims::Dims{N}
+
+     # allocating constructor
+     function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
+         check_eltype(T)
+         maxsize = prod(dims) * sizeof(T)
+         data = Vector{UInt8}(undef, maxsize)
+         ref = DataRef(data)
+         obj = new{T,N}(ref, 0, dims)
+         finalizer(unsafe_free!, obj)
    end
-     return
- end

+     # low-level constructor for wrapping existing data
+     function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
+                           offset::Int=0) where {T,N}
+         check_eltype(T)
+         obj = new{T,N}(ref, offset, dims)
+         finalizer(unsafe_free!, obj)
+     end
+ end
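A quick sketch of the allocating constructor in action (dimensions are in elements; the backing buffer is raw bytes):

```julia
using JLArrays

a = JLArray{Float32,2}(undef, (2, 3))  # backed by a 24-byte Vector{UInt8}, offset 0
size(a)  # (2, 3)
```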

- ## executed on-device
+ Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
+ Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
+ Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
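These three rules let `adapt` move data onto and off the fake device; for example:

```julia
using Adapt, KernelAbstractions, JLArrays

cpu = rand(Float32, 8)
dev = adapt(JLBackend(), cpu)                # Array -> JLArray
host = adapt(KernelAbstractions.CPU(), dev)  # JLArray -> Array
```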

# array type

⋮
@inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
@inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)

-
- # indexing
-
- for f in (:blockidx, :blockdim, :threadidx, :griddim)
-     @eval GPUArrays.$f(ctx::JLKernelContext) = ctx.$f
- end
-
- # memory
-
- function GPUArrays.LocalMemory(ctx::JLKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
-     ctx.localmem_counter += 1
-     lmems = ctx.localmems[blockidx(ctx)]
-
-     # first invocation in block
-     data = if length(lmems) < ctx.localmem_counter
-         lmem = fill(zero(T), dims)
-         push!(lmems, lmem)
-         lmem
-     else
-         lmems[ctx.localmem_counter]
-     end
-
-     N = length(dims)
-     JLDeviceArray{T,N}(data, tuple(dims...))
- end
-
- # synchronization
-
- @inline function GPUArrays.synchronize_threads(::JLKernelContext)
-     # All threads are getting started asynchronously, so a yield will yield to the next
-     # execution of the same function, which should call yield at the exact same point in the
-     # program, leading to a chain of yields effectively syncing the tasks (threads).
-     yield()
-     return
- end
-
-
#
# Host abstractions
#
@@ -157,32 +102,6 @@ function check_eltype(T)
    end
end

- mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
-     data::DataRef{Vector{UInt8}}
-
-     offset::Int  # offset of the data in the buffer, in number of elements
-
-     dims::Dims{N}
-
-     # allocating constructor
-     function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
-         check_eltype(T)
-         maxsize = prod(dims) * sizeof(T)
-         data = Vector{UInt8}(undef, maxsize)
-         ref = DataRef(data)
-         obj = new{T,N}(ref, 0, dims)
-         finalizer(unsafe_free!, obj)
-     end
-
-     # low-level constructor for wrapping existing data
-     function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
-                           offset::Int=0) where {T,N}
-         check_eltype(T)
-         obj = new{T,N}(ref, offset, dims)
-         finalizer(unsafe_free!, obj)
-     end
- end
-
unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)

# conversion of untyped data to a typed Array

⋮

## GPUArrays interfaces

- GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
-
Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
    JLDeviceArray{T,N}(x.data[], x.offset, x.dims)

@@ -406,4 +323,47 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
    R
end

+ ## KernelAbstractions interface
+
+ KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
+
+ function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
+     return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
+ end
+
+ KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims)
+
+ @inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize)
+     if ndrange isa Integer
+         ndrange = (ndrange,)
+     end
+     if workgroupsize isa Integer
+         workgroupsize = (workgroupsize,)
+     end
+
+     if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
+         workgroupsize = (1024,)  # vectorization, 4x unrolling, minimal grain size
+     end
+     iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
+     # partition checked that the ndranges agreed
+     if KernelAbstractions.ndrange(kernel) <: StaticSize
+         ndrange = nothing
+     end
+
+     return ndrange, workgroupsize, iterspace, dynamic
+ end
+
+ KernelAbstractions.isgpu(b::JLBackend) = false
+
+ function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F}
+     return Kernel{typeof(KernelAbstractions.CPU(; static=obj.backend.static)), W, N, F}(KernelAbstractions.CPU(; static=obj.backend.static), obj.f)
+ end
+
+ function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+     device_args = jlconvert.(args)
+     new_obj = convert_to_cpu(obj)
+     new_obj(device_args...; ndrange, workgroupsize)
+ end
+
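Putting the interface together, a hedged end-to-end sketch (the `scale!` kernel is hypothetical; any `@kernel` would do). The call overload above converts the kernel to a plain KernelAbstractions `CPU` kernel, adapts the arguments, and forwards `ndrange`/`workgroupsize`:

```julia
using JLArrays, KernelAbstractions

@kernel function scale!(a, s)
    i = @index(Global)
    a[i] *= s
end

a = jl(ones(Float32, 32))
scale!(JLBackend())(a, 2f0; ndrange=length(a))  # runs as CPU tasks via KernelAbstractions
Array(a)  # all elements are now 2.0f0
```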
end