
Commit 0800eee

leios authored and vchuravy committed
Transition GPUArrays to KernelAbstractions
Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
1 parent 7970d56 commit 0800eee

Showing 25 changed files with 324 additions and 614 deletions.

Project.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,6 +5,7 @@ version = "10.0.2"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
```

docs/src/interface.md

Lines changed: 0 additions & 2 deletions
````diff
@@ -13,8 +13,6 @@ all, you need to provide a type that represents your execution back-end and a way to launch
 kernels:
 
 ```@docs
-GPUArrays.AbstractGPUBackend
-GPUArrays.AbstractKernelContext
 GPUArrays.gpu_call
 GPUArrays.thread_block_heuristic
 ```
````

lib/GPUArraysCore/src/GPUArraysCore.jl

Lines changed: 3 additions & 3 deletions
```diff
@@ -222,10 +222,10 @@ end
 
 Gets the GPUArrays back-end responsible for managing arrays of type `T`.
 """
-backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
-backend(x) = backend(typeof(x))
+get_backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
+get_backend(x) = get_backend(typeof(x))
 
 # WrappedArray from Adapt for Base wrappers.
-backend(::Type{WA}) where WA<:WrappedArray = backend(unwrap_type(WA))
+get_backend(::Type{WA}) where WA<:WrappedArray = get_backend(unwrap_type(WA))
 
 end # module GPUArraysCore
```
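
As a usage sketch (not part of the commit), this is how a downstream package would hook an array type into the renamed query; `MyBackend` and `MyArray` are hypothetical names, and wrapped views are routed through the `WrappedArray` fallback shown above:

```julia
using GPUArraysCore

# Hypothetical backend and array type, for illustration only.
struct MyBackend end

struct MyArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::MyArray) = size(a.data)
Base.getindex(a::MyArray{T,N}, i::Vararg{Int,N}) where {T,N} = a.data[i...]

# The one method a backend package needs to provide:
GPUArraysCore.get_backend(::Type{<:MyArray}) = MyBackend()

a = MyArray(rand(Float32, 4, 4))
GPUArraysCore.get_backend(a)              # MyBackend(), via get_backend(typeof(a))
GPUArraysCore.get_backend(view(a, :, 1))  # also MyBackend(), via the WrappedArray fallback
```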

lib/JLArrays/Project.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -6,10 +6,11 @@ version = "0.1.4"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [compat]
 Adapt = "2.0, 3.0, 4.0"
 GPUArrays = "10"
-julia = "1.8"
 Random = "1"
+julia = "1.8"
```

lib/JLArrays/src/JLArrays.jl

Lines changed: 79 additions & 119 deletions
```diff
@@ -1,54 +1,29 @@
 # reference implementation on the CPU
-
-# note that most of the code in this file serves to define a functional array type,
-# the actual implementation of GPUArrays-interfaces is much more limited.
+# This acts as a wrapper around KernelAbstractions's parallel CPU
+# functionality. It is useful for testing GPUArrays (and other packages)
+# when no GPU is present.
+# This file follows conventions from AMDGPU.jl
 
 module JLArrays
 
-export JLArray, JLVector, JLMatrix, jl
-
 using GPUArrays
-
 using Adapt
+import KernelAbstractions
+import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
 
+export JLArray, JLVector, JLMatrix, jl, JLBackend
 
 #
 # Device functionality
 #
 
 const MAXTHREADS = 256
 
-
-## execution
-
-struct JLBackend <: AbstractGPUBackend end
-
-mutable struct JLKernelContext <: AbstractKernelContext
-    blockdim::Int
-    griddim::Int
-    blockidx::Int
-    threadidx::Int
-
-    localmem_counter::Int
-    localmems::Vector{Vector{Array}}
-end
-
-function JLKernelContext(threads::Int, blockdim::Int)
-    blockcount = prod(blockdim)
-    lmems = [Vector{Array}() for i in 1:blockcount]
-    JLKernelContext(threads, blockdim, 1, 1, 0, lmems)
+struct JLBackend <: KernelAbstractions.GPU
+    static::Bool
+    JLBackend(;static::Bool=false) = new(static)
 end
 
-function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
-    JLKernelContext(
-        ctx.blockdim,
-        ctx.griddim,
-        ctx.blockidx,
-        threadidx,
-        0,
-        ctx.localmems
-    )
-end
 
 struct Adaptor end
 jlconvert(arg) = adapt(Adaptor(), arg)
@@ -60,28 +35,35 @@ end
 Base.getindex(r::JlRefValue) = r.x
 Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
 
-function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
-                            name::Union{String,Nothing})
-    ctx = JLKernelContext(threads, blocks)
-    device_args = jlconvert.(args)
-    tasks = Array{Task}(undef, threads)
-    for blockidx in 1:blocks
-        ctx.blockidx = blockidx
-        for threadidx in 1:threads
-            thread_ctx = JLKernelContext(ctx, threadidx)
-            tasks[threadidx] = @async f(thread_ctx, device_args...)
-            # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
-            # (this would require a different synchronization mechanism)
-        end
-        for t in tasks
-            fetch(t)
-        end
+mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
+    data::DataRef{Vector{UInt8}}
+
+    offset::Int   # offset of the data in the buffer, in number of elements
+
+    dims::Dims{N}
+
+    # allocating constructor
+    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
+        check_eltype(T)
+        maxsize = prod(dims) * sizeof(T)
+        data = Vector{UInt8}(undef, maxsize)
+        ref = DataRef(data)
+        obj = new{T,N}(ref, 0, dims)
+        finalizer(unsafe_free!, obj)
     end
-    return
-end
 
+    # low-level constructor for wrapping existing data
+    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
+                          offset::Int=0) where {T,N}
+        check_eltype(T)
+        obj = new{T,N}(ref, offset, dims)
+        finalizer(unsafe_free!, obj)
+    end
+end
 
-## executed on-device
+Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
+Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
+Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
 
 # array type
 
@@ -107,43 +89,6 @@ end
 @inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
 @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)
 
-
-# indexing
-
-for f in (:blockidx, :blockdim, :threadidx, :griddim)
-    @eval GPUArrays.$f(ctx::JLKernelContext) = ctx.$f
-end
-
-# memory
-
-function GPUArrays.LocalMemory(ctx::JLKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
-    ctx.localmem_counter += 1
-    lmems = ctx.localmems[blockidx(ctx)]
-
-    # first invocation in block
-    data = if length(lmems) < ctx.localmem_counter
-        lmem = fill(zero(T), dims)
-        push!(lmems, lmem)
-        lmem
-    else
-        lmems[ctx.localmem_counter]
-    end
-
-    N = length(dims)
-    JLDeviceArray{T,N}(data, tuple(dims...))
-end
-
-# synchronization
-
-@inline function GPUArrays.synchronize_threads(::JLKernelContext)
-    # All threads are getting started asynchronously, so a yield will yield to the next
-    # execution of the same function, which should call yield at the exact same point in the
-    # program, leading to a chain of yields effectively syncing the tasks (threads).
-    yield()
-    return
-end
-
-
 #
 # Host abstractions
 #
@@ -157,32 +102,6 @@ function check_eltype(T)
     end
 end
 
-mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
-    data::DataRef{Vector{UInt8}}
-
-    offset::Int   # offset of the data in the buffer, in number of elements
-
-    dims::Dims{N}
-
-    # allocating constructor
-    function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
-        check_eltype(T)
-        maxsize = prod(dims) * sizeof(T)
-        data = Vector{UInt8}(undef, maxsize)
-        ref = DataRef(data)
-        obj = new{T,N}(ref, 0, dims)
-        finalizer(unsafe_free!, obj)
-    end
-
-    # low-level constructor for wrapping existing data
-    function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
-                          offset::Int=0) where {T,N}
-        check_eltype(T)
-        obj = new{T,N}(ref, offset, dims)
-        finalizer(unsafe_free!, obj)
-    end
-end
-
 unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)
 
 # conversion of untyped data to a typed Array
@@ -392,8 +311,6 @@ end
 
 ## GPUArrays interfaces
 
-GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
-
 Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
     JLDeviceArray{T,N}(x.data[], x.offset, x.dims)
 
@@ -406,4 +323,47 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
     R
 end
 
+## KernelAbstractions interface
+
+KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
+
+function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
+    return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
+end
+
+KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims)
+
+@inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize)
+    if ndrange isa Integer
+        ndrange = (ndrange,)
+    end
+    if workgroupsize isa Integer
+        workgroupsize = (workgroupsize,)
+    end
+
+    if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
+        workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
+    end
+    iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
+    # partition checked that the ndranges agreed
+    if KernelAbstractions.ndrange(kernel) <: StaticSize
+        ndrange = nothing
+    end
+
+    return ndrange, workgroupsize, iterspace, dynamic
+end
+
+KernelAbstractions.isgpu(b::JLBackend) = false
+
+function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F}
+    return Kernel{typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F}(KernelAbstractions.CPU(; static = obj.backend.static), obj.f)
+end
+
+function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+    device_args = jlconvert.(args)
+    new_obj = convert_to_cpu(obj)
+    new_obj(device_args...; ndrange, workgroupsize)
+end
+
 end
```
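
As a usage sketch (not part of the commit), a portable KernelAbstractions kernel now runs on `JLArray`s through the `Kernel{JLBackend}` call path defined above; the kernel name `scale!` is made up for illustration:

```julia
using KernelAbstractions, JLArrays

# A made-up elementwise kernel written against the portable KA API.
@kernel function scale!(y, x, alpha)
    i = @index(Global)
    @inbounds y[i] = alpha * x[i]
end

x = JLArray(ones(Float32, 1024))
y = similar(x)

# Instantiating against JLBackend yields a Kernel{JLBackend}; calling it hits
# the method from this commit, which adapts the arguments via jlconvert and
# re-launches the kernel on KernelAbstractions.CPU.
kernel = scale!(JLBackend())
kernel(y, x, 2f0; ndrange = length(x))
```

Because the launch is forwarded to the CPU backend, JLArrays exercises the same generic code path a real GPU backend would, which is what makes it useful for testing without a GPU.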

src/GPUArrays.jl

Lines changed: 3 additions & 5 deletions
```diff
@@ -1,5 +1,6 @@
 module GPUArrays
 
+using KernelAbstractions
 using Serialization
 using Random
 using LinearAlgebra
@@ -14,14 +15,11 @@ using LLVM.Interop
 using Reexport
 @reexport using GPUArraysCore
 
-# device functionality
-include("device/execution.jl")
 ## executed on-device
+include("device/execution.jl")
 include("device/abstractarray.jl")
-include("device/indexing.jl")
-include("device/memory.jl")
-include("device/synchronization.jl")
 
+using KernelAbstractions
 # host abstractions
 include("host/abstractarray.jl")
 include("host/construction.jl")
```
