@@ -1,54 +1,23 @@
 # reference implementation on the CPU
-
-# note that most of the code in this file serves to define a functional array type,
-# the actual implementation of GPUArrays-interfaces is much more limited.
+# This acts as a wrapper around KernelAbstractions's parallel CPU
+# functionality. It is useful for testing GPUArrays (and other packages)
+# when no GPU is present.
+# This file follows conventions from AMDGPU.jl.
 
 module JLArrays
 
-export JLArray, JLVector, JLMatrix, jl
-
 using GPUArrays
-
 using Adapt
+import KernelAbstractions
+import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
 
+export JLArray, JLVector, JLMatrix, jl, JLBackend
 
-#
-# Device functionality
-#
-
-const MAXTHREADS = 256
-
-
-## execution
-
-struct JLBackend <: AbstractGPUBackend end
-
-mutable struct JLKernelContext <: AbstractKernelContext
-    blockdim::Int
-    griddim::Int
-    blockidx::Int
-    threadidx::Int
-
-    localmem_counter::Int
-    localmems::Vector{Vector{Array}}
+struct JLBackend <: KernelAbstractions.GPU
+    static::Bool
+    JLBackend(;static::Bool=false) = new(static)
 end
 
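+# The `static` field mirrors the flag of the KernelAbstractions CPU backend;
+# convert_to_cpu (defined below) forwards it when kernels are re-targeted to the host.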
-function JLKernelContext(threads::Int, blockdim::Int)
-    blockcount = prod(blockdim)
-    lmems = [Vector{Array}() for i in 1:blockcount]
-    JLKernelContext(threads, blockdim, 1, 1, 0, lmems)
-end
-
-function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
-    JLKernelContext(
-        ctx.blockdim,
-        ctx.griddim,
-        ctx.blockidx,
-        threadidx,
-        0,
-        ctx.localmems
-    )
-end
 
 struct Adaptor end
 jlconvert(arg) = adapt(Adaptor(), arg)
@@ -60,28 +29,35 @@
 Base.getindex(r::JlRefValue) = r.x
 Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
 
-function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
-                            name::Union{String,Nothing})
-    ctx = JLKernelContext(threads, blocks)
-    device_args = jlconvert.(args)
-    tasks = Array{Task}(undef, threads)
-    for blockidx in 1:blocks
-        ctx.blockidx = blockidx
-        for threadidx in 1:threads
-            thread_ctx = JLKernelContext(ctx, threadidx)
-            tasks[threadidx] = @async f(thread_ctx, device_args...)
-            # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
-            # (this would require a different synchronization mechanism)
-        end
-        for t in tasks
-            fetch(t)
-        end
+mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
+    data::DataRef{Vector{UInt8}}
+
+    offset::Int # offset of the data in the buffer, in number of elements
+
+    dims::Dims{N}
+
+    # allocating constructor
+    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
+        check_eltype(T)
+        maxsize = prod(dims) * sizeof(T)
+        data = Vector{UInt8}(undef, maxsize)
+        ref = DataRef(data)
+        obj = new{T,N}(ref, 0, dims)
+        finalizer(unsafe_free!, obj)
     end
-    return
-end
 
+    # low-level constructor for wrapping existing data
+    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
+                          offset::Int=0) where {T,N}
+        check_eltype(T)
+        obj = new{T,N}(ref, offset, dims)
+        finalizer(unsafe_free!, obj)
+    end
+end
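+# (The array contents live in a plain byte buffer behind a reference-counted
+# DataRef, so derived arrays such as views can share it through `offset`.)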
 
-## executed on-device
+Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
+Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
+Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
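+# These rules let `adapt(JLBackend(), x)` move host arrays to the device type
+# and `adapt(KernelAbstractions.CPU(), x)` convert them back.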
 
 # array type
 
@@ -107,43 +83,6 @@
 @inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
 @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)
 
-
-# indexing
-
-for f in (:blockidx, :blockdim, :threadidx, :griddim)
-    @eval GPUArrays.$f(ctx::JLKernelContext) = ctx.$f
-end
-
-# memory
-
-function GPUArrays.LocalMemory(ctx::JLKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
-    ctx.localmem_counter += 1
-    lmems = ctx.localmems[blockidx(ctx)]
-
-    # first invocation in block
-    data = if length(lmems) < ctx.localmem_counter
-        lmem = fill(zero(T), dims)
-        push!(lmems, lmem)
-        lmem
-    else
-        lmems[ctx.localmem_counter]
-    end
-
-    N = length(dims)
-    JLDeviceArray{T,N}(data, tuple(dims...))
-end
-
-# synchronization
-
-@inline function GPUArrays.synchronize_threads(::JLKernelContext)
-    # All threads are getting started asynchronously, so a yield will yield to the next
-    # execution of the same function, which should call yield at the exact same point in the
-    # program, leading to a chain of yields effectively syncing the tasks (threads).
-    yield()
-    return
-end
-
-
 #
 # Host abstractions
 #
@@ -157,32 +96,6 @@ function check_eltype(T)
     end
 end
 
-mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
-    data::DataRef{Vector{UInt8}}
-
-    offset::Int # offset of the data in the buffer, in number of elements
-
-    dims::Dims{N}
-
-    # allocating constructor
-    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
-        check_eltype(T)
-        maxsize = prod(dims) * sizeof(T)
-        data = Vector{UInt8}(undef, maxsize)
-        ref = DataRef(data)
-        obj = new{T,N}(ref, 0, dims)
-        finalizer(unsafe_free!, obj)
-    end
-
-    # low-level constructor for wrapping existing data
-    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
-                          offset::Int=0) where {T,N}
-        check_eltype(T)
-        obj = new{T,N}(ref, offset, dims)
-        finalizer(unsafe_free!, obj)
-    end
-end
-
 unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)
 
 # conversion of untyped data to a typed Array
@@ -392,8 +305,6 @@
 
 ## GPUArrays interfaces
 
-GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
-
 Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
     JLDeviceArray{T,N}(x.data[], x.offset, x.dims)
 
@@ -406,4 +317,47 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
     R
 end
 
+## KernelAbstractions interface
+
+KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
+
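+# Kernel launches reuse KernelAbstractions' CompilerMetadata to carry the
+# ndrange and iteration-space information, as the stock CPU backend does.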
+function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
+    return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
+end
+
+KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims)
+
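+# Mirrors the CPU backend's launch configuration: normalize scalar ranges to
+# tuples and pick a default workgroup size for dynamically sized kernels.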
+@inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize)
+    if ndrange isa Integer
+        ndrange = (ndrange,)
+    end
+    if workgroupsize isa Integer
+        workgroupsize = (workgroupsize,)
+    end
+
+    if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
+        workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
+    end
+    iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
+    # partition checked that the ndranges agreed
+    if KernelAbstractions.ndrange(kernel) <: StaticSize
+        ndrange = nothing
+    end
+
+    return ndrange, workgroupsize, iterspace, dynamic
+end
+
+KernelAbstractions.isgpu(b::JLBackend) = false
+
+function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F}
+    return Kernel{typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F}(KernelAbstractions.CPU(; static = obj.backend.static), obj.f)
+end
+
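+# Calling a JLBackend kernel re-wraps it as a KernelAbstractions CPU kernel
+# (preserving the `static` scheduling flag) and runs it on the host, after
+# adapting the arguments to device-side array types.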
+function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+    device_args = jlconvert.(args)
+    new_obj = convert_to_cpu(obj)
+    new_obj(device_args...; ndrange, workgroupsize)
+end
+
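+# Hypothetical usage sketch (not part of this file): any KernelAbstractions
+# kernel can now execute against JLArrays through JLBackend, e.g.
+#
+#     using KernelAbstractions
+#     @kernel function plus_one!(A)
+#         i = @index(Global)
+#         @inbounds A[i] += one(eltype(A))
+#     end
+#
+#     A = JLArray(zeros(Float32, 16))
+#     plus_one!(JLBackend())(A; ndrange=length(A))
+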
 end