
Commit 2769d6b

maleadt and leios authored
Switch to KernelAbstractions.jl (#559)
Co-authored-by: James Schloss <jrs.schloss@gmail.com>
1 parent 8dfd805 commit 2769d6b

29 files changed: +289 −691 lines

.buildkite/pipeline.yml

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,8 @@ steps:
       cuda: "*"
     if: build.message !~ /\[skip tests\]/
     timeout_in_minutes: 120
+    soft_fail:
+      - exit_status: 3
 
   - label: "oneAPI.jl"
     plugins:
@@ -87,6 +89,8 @@ steps:
       intel: "*"
     if: build.message !~ /\[skip tests\]/
     timeout_in_minutes: 60
+    soft_fail:
+      - exit_status: 3
 
   - label: "Metal.jl"
     plugins:
@@ -132,6 +136,8 @@ steps:
       arch: "aarch64"
     if: build.message !~ /\[skip tests\]/
     timeout_in_minutes: 60
+    soft_fail:
+      - exit_status: 3
 
 env:
   SECRET_CODECOV_TOKEN: "GrevHmzmr2Vt6UK4wbbTTB1+kcMcIlF6nCXVCk3Z0plHDimpD6BwdN9T2A+5J9k3I2em0xXUqpt+2qUSqM8Bn5mNdpjR0TvxVY3oYXc+qzvBXmcZJpuCgJeoTP1P+kVFwszUn4na3fohNq9Qffp6tXMn/j8yJQKOiiC8mkD0aPEI0zISHuDaa/7j7JYf0vTrMRRZ9BMUQHmFuVaIQN8FLGG2BiE3236rj4eHh0lj2IfekCG3wd/LUzAsMx0MC3kIR8WzOWW2rf6xUMPkjm5+NuHwhAOcZc0+LRM7GYIwoW/nHAgyIqjvLiInNFmaJk+7V/GAKtd+gSAIzmyBUHAy6A==;U2FsdGVkX1+4ZljneQoaNE295nRIx8D6+WoFIgT6Pg2BXHaTyhTL4sxEcG0jX0e7oq68uvi4bK7x7YMS4L0Kew=="

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ jobs:
             TestEnv.activate()
           catch err
             @error "Could not install OpenCL.jl" exception=(err,catch_backtrace())
-            exit(3)
+            exit(0)
           finally
             Pkg.activate(package)
           end
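For context, the `exit` call changed above sits inside a `try`/`catch`/`finally` block in the workflow's Julia test-setup script. The sketch below reconstructs that pattern under stated assumptions: the `try` opening, the `package` variable, and the `using` statements are not part of the hunk and are placeholders.

```julia
using Pkg, TestEnv

package = dirname(Base.active_project())  # placeholder: project to return to
try
    TestEnv.activate()   # set up the test environment for the downstream package
catch err
    @error "Could not install OpenCL.jl" exception=(err,catch_backtrace())
    exit(0)              # after this commit: exit cleanly instead of exit(3)
finally
    Pkg.activate(package)
end
```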

Project.toml

Lines changed: 3 additions & 2 deletions
@@ -1,10 +1,11 @@
 name = "GPUArrays"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "10.3.1"
+version = "11.0.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -15,7 +16,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 Adapt = "4.0"
-GPUArraysCore = "= 0.1.6"
+GPUArraysCore = "= 0.2.0"
 LLVM = "3.9, 4, 5, 6, 7, 8, 9"
 LinearAlgebra = "1"
 Printf = "1"

docs/src/index.md

Lines changed: 2 additions & 3 deletions
@@ -9,10 +9,9 @@ will get a lot of functionality for free. This will allow to have multiple GPUAr
 implementation for different purposes, while maximizing the ability to share code.
 
 **This package is not intended for end users!** Instead, you should use one of the packages
-that builds on GPUArrays.jl. There is currently only a single package that actively builds
-on these interfaces, namely [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl).
+that builds on GPUArrays.jl such as [CUDA](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU](https://github.com/JuliaGPU/AMDGPU.jl), [OneAPI](https://github.com/JuliaGPU/oneAPI.jl), or [Metal](https://github.com/JuliaGPU/Metal.jl).
 
-In this documentation, you will find more information on the interface that you are expected
+This documentation is meant for users who might wish to implement a version of GPUArrays for another GPU backend and will cover the features you will need
 to implement, the functionality you gain by doing so, and the test suite that is available
 to verify your implementation. GPUArrays.jl also provides a reference implementation of
 these interfaces on the CPU: The `JLArray` array type uses Julia's parallel programming

docs/src/interface.md

Lines changed: 18 additions & 39 deletions
@@ -1,53 +1,32 @@
 # Interface
 
 To extend the above functionality to a new array type, you should use the types and
-implement the interfaces listed on this page. GPUArrays is design around having two
-different array types to represent a GPU array: one that only ever lives on the host, and
+implement the interfaces listed on this page. GPUArrays is designed around having two
+different array types to represent a GPU array: one that exists only on the host, and
 one that actually can be instantiated on the device (i.e. in kernels).
+Device functionality is then handled by [KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl).
 
+## Host abstractions
 
-## Device functionality
-
-Several types and interfaces are related to the device and execution of code on it. First of
-all, you need to provide a type that represents your execution back-end and a way to call
-kernels:
+You should provide an array type that builds on the `AbstractGPUArray` supertype, such as:
 
-```@docs
-GPUArrays.AbstractGPUBackend
-GPUArrays.AbstractKernelContext
-GPUArrays.gpu_call
-GPUArrays.thread_block_heuristic
 ```
+mutable struct CustomArray{T, N} <: AbstractGPUArray{T, N}
+    data::DataRef{Vector{UInt8}}
+    offset::Int
+    dims::Dims{N}
+    ...
+end
 
-You then need to provide implementations of certain methods that will be executed on the
-device itself:
-
-```@docs
-GPUArrays.AbstractDeviceArray
-GPUArrays.LocalMemory
-GPUArrays.synchronize_threads
-GPUArrays.blockidx
-GPUArrays.blockdim
-GPUArrays.threadidx
-GPUArrays.griddim
 ```
 
+This will allow your defined type (in this case `JLArray`) to use the GPUArrays interface where available.
+To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you need to define the backend, like so:
 
-## Host abstractions
-
-You should provide an array type that builds on the `AbstractGPUArray` supertype:
-
-```@docs
-AbstractGPUArray
 ```
-
-First of all, you should implement operations that are expected to be defined for any
-`AbstractArray` type. Refer to the Julia manual for more details, or look at the `JLArray`
-reference implementation.
-
-To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you
-should provide implementations of the following interfaces:
-
-```@docs
-GPUArrays.backend
+import KernelAbstractions: Backend
+struct CustomBackend <: KernelAbstractions.GPU
+KernelAbstractions.get_backend(a::CA) where CA <: CustomArray = CustomBackend()
 ```
+
+There are numerous examples of potential interfaces for GPUArrays, such as with [JLArrays](https://github.com/JuliaGPU/GPUArrays.jl/blob/master/lib/JLArrays/src/JLArrays.jl), [CuArrays](https://github.com/JuliaGPU/CUDA.jl/blob/master/src/gpuarrays.jl), and [ROCArrays](https://github.com/JuliaGPU/AMDGPU.jl/blob/master/src/gpuarrays.jl).
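As a reading aid, here is a self-contained sketch of the backend wiring the updated interface documentation describes. `CustomArray` and `CustomBackend` are hypothetical names, the data layout is simplified (a plain `Vector` instead of `DataRef`), and the `struct` is explicitly closed with `end`:

```julia
using GPUArrays
import KernelAbstractions

# A host-side array type building on AbstractGPUArray (simplified sketch).
mutable struct CustomArray{T, N} <: AbstractGPUArray{T, N}
    data::Vector{UInt8}   # real backends would track a device buffer here
    offset::Int
    dims::Dims{N}
end

Base.size(a::CustomArray) = a.dims

# The execution backend: a KernelAbstractions.GPU subtype, returned by get_backend.
struct CustomBackend <: KernelAbstractions.GPU end

KernelAbstractions.get_backend(::CustomArray) = CustomBackend()
```

With `get_backend` defined, GPUArrays' generic host functionality can launch KernelAbstractions kernels on the new backend.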

lib/GPUArraysCore/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "GPUArraysCore"
 uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
 authors = ["Tim Besard <tim.besard@gmail.com>"]
-version = "0.1.6"
+version = "0.2.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

lib/GPUArraysCore/src/GPUArraysCore.jl

Lines changed: 0 additions & 15 deletions
@@ -209,19 +209,4 @@ macro allowscalar(ex)
     end
 end
 
-
-## other
-
-"""
-    backend(x)
-    backend(T::Type)
-
-Gets the GPUArrays back-end responsible for managing arrays of type `T`.
-"""
-backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
-backend(x) = backend(typeof(x))
-
-# WrappedArray from Adapt for Base wrappers.
-backend(::Type{WA}) where WA<:WrappedArray = backend(unwrap_type(WA))
-
 end # module GPUArraysCore
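Since `GPUArrays.backend` is removed from GPUArraysCore, downstream code queries the backend through KernelAbstractions instead. A small illustration using the `JLArray` reference implementation from this repository (assuming both packages are installed):

```julia
using JLArrays
import KernelAbstractions

x = JLArray(rand(Float32, 16))

# Previously: GPUArrays.backend(typeof(x)) returned an AbstractGPUBackend.
# Now the KernelAbstractions backend is queried from the array itself:
backend = KernelAbstractions.get_backend(x)                 # JLBackend()

# and used through the standard KernelAbstractions entry points:
y = KernelAbstractions.allocate(backend, Float32, (16,))    # a 16-element JLArray
```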

lib/JLArrays/Project.toml

Lines changed: 3 additions & 2 deletions
@@ -6,10 +6,11 @@ version = "0.1.6"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [compat]
 Adapt = "2.0, 3.0, 4.0"
-GPUArrays = "10"
-julia = "1.8"
+GPUArrays = "11"
 Random = "1"
+julia = "1.8"

lib/JLArrays/src/JLArrays.jl

Lines changed: 57 additions & 93 deletions
@@ -1,53 +1,30 @@
 # reference implementation on the CPU
-
-# note that most of the code in this file serves to define a functional array type,
-# the actual implementation of GPUArrays-interfaces is much more limited.
+# This acts as a wrapper around KernelAbstractions's parallel CPU
+# functionality. It is useful for testing GPUArrays (and other packages)
+# when no GPU is present.
+# This file follows conventions from AMDGPU.jl
 
 module JLArrays
 
-export JLArray, JLVector, JLMatrix, jl
+export JLArray, JLVector, JLMatrix, jl, JLBackend
 
 using GPUArrays
 
 using Adapt
 
+import KernelAbstractions
+import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
+
 
 #
 # Device functionality
 #
 
 const MAXTHREADS = 256
 
-
-## execution
-
-struct JLBackend <: AbstractGPUBackend end
-
-mutable struct JLKernelContext <: AbstractKernelContext
-    blockdim::Int
-    griddim::Int
-    blockidx::Int
-    threadidx::Int
-
-    localmem_counter::Int
-    localmems::Vector{Vector{Array}}
-end
-
-function JLKernelContext(threads::Int, blockdim::Int)
-    blockcount = prod(blockdim)
-    lmems = [Vector{Array}() for i in 1:blockcount]
-    JLKernelContext(threads, blockdim, 1, 1, 0, lmems)
-end
-
-function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
-    JLKernelContext(
-        ctx.blockdim,
-        ctx.griddim,
-        ctx.blockidx,
-        threadidx,
-        0,
-        ctx.localmems
-    )
+struct JLBackend <: KernelAbstractions.GPU
+    static::Bool
+    JLBackend(;static::Bool=false) = new(static)
 end
 
 struct Adaptor end
@@ -60,27 +37,6 @@ end
 Base.getindex(r::JlRefValue) = r.x
 Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
 
-function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
-                            name::Union{String,Nothing})
-    ctx = JLKernelContext(threads, blocks)
-    device_args = jlconvert.(args)
-    tasks = Array{Task}(undef, threads)
-    for blockidx in 1:blocks
-        ctx.blockidx = blockidx
-        for threadidx in 1:threads
-            thread_ctx = JLKernelContext(ctx, threadidx)
-            tasks[threadidx] = @async f(thread_ctx, device_args...)
-            # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
-            # (this would require a different synchronization mechanism)
-        end
-        for t in tasks
-            fetch(t)
-        end
-    end
-    return
-end
-
-
 ## executed on-device
 
 # array type
@@ -108,42 +64,6 @@ end
 @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)
 
 
-# indexing
-
-for f in (:blockidx, :blockdim, :threadidx, :griddim)
-    @eval GPUArrays.$f(ctx::JLKernelContext) = ctx.$f
-end
-
-# memory
-
-function GPUArrays.LocalMemory(ctx::JLKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}) where {T, dims, id}
-    ctx.localmem_counter += 1
-    lmems = ctx.localmems[blockidx(ctx)]
-
-    # first invocation in block
-    data = if length(lmems) < ctx.localmem_counter
-        lmem = fill(zero(T), dims)
-        push!(lmems, lmem)
-        lmem
-    else
-        lmems[ctx.localmem_counter]
-    end
-
-    N = length(dims)
-    JLDeviceArray{T,N}(data, tuple(dims...))
-end
-
-# synchronization
-
-@inline function GPUArrays.synchronize_threads(::JLKernelContext)
-    # All threads are getting started asynchronously, so a yield will yield to the next
-    # execution of the same function, which should call yield at the exact same point in the
-    # program, leading to a chain of yields effectively syncing the tasks (threads).
-    yield()
-    return
-end
-
-
 #
 # Host abstractions
 #
@@ -409,8 +329,6 @@ end
 
 ## GPUArrays interfaces
 
-GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
-
 Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
     JLDeviceArray{T,N}(x.data[], x.offset, x.dims)
 
@@ -423,4 +341,50 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
     R
 end
 
+## KernelAbstractions interface
+
+KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
+
+function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
+    return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
+end
+
+KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims)
+
+@inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize)
+    if ndrange isa Integer
+        ndrange = (ndrange,)
+    end
+    if workgroupsize isa Integer
+        workgroupsize = (workgroupsize, )
+    end
+
+    if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
+        workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
+    end
+    iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
+    # partition checked that the ndrange's agreed
+    if KernelAbstractions.ndrange(kernel) <: StaticSize
+        ndrange = nothing
+    end
+
+    return ndrange, workgroupsize, iterspace, dynamic
+end
+
+KernelAbstractions.isgpu(b::JLBackend) = false
+
+function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F}
+    return Kernel{typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F}(KernelAbstractions.CPU(; static = obj.backend.static), obj.f)
+end
+
+function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+    device_args = jlconvert.(args)
+    new_obj = convert_to_cpu(obj)
+    new_obj(device_args...; ndrange, workgroupsize)
+end
+
+Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
+Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
+Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
+
 end
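To illustrate how the reworked reference backend is meant to be consumed (a usage sketch, not code from this commit): a KernelAbstractions kernel is instantiated with the backend obtained from a `JLArray` and launched with an `ndrange`.

```julia
using JLArrays
using KernelAbstractions

# A trivial elementwise kernel written against the KernelAbstractions API.
@kernel function axpy_kernel!(y, a, x)
    i = @index(Global)
    @inbounds y[i] += a * x[i]
end

x = JLArray(rand(Float32, 1024))
y = JLArray(zeros(Float32, 1024))

backend = KernelAbstractions.get_backend(y)   # JLBackend()
kernel! = axpy_kernel!(backend)               # instantiate the kernel for this backend
kernel!(y, 2f0, x; ndrange = length(y))       # runs on CPU tasks via the JLBackend wrapper
# Array(y) now holds 2 .* the original values of x
```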
