Skip to content

Add accumulation for all backends using AcceleratedKernels.jl #606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ steps:
- label: "oneAPI.jl"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
version: "1.11"
- JuliaCI/julia-coverage#v1:
codecov: true
command: |
Expand Down Expand Up @@ -95,7 +95,7 @@ steps:
- label: "Metal.jl"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
version: "1.11"
- JuliaCI/julia-coverage#v1:
codecov: true
command: |
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "11.2.3"

[deps]
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -22,6 +23,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JLD2Ext = "JLD2"

[compat]
AcceleratedKernels = "0.4"
Adapt = "4.0"
GPUArraysCore = "= 0.2.0"
JLD2 = "0.4, 0.5"
Expand Down
37 changes: 37 additions & 0 deletions lib/JLArrays/src/JLArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,43 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
R
end

## Base interface

# Accumulation entry points reached from `Base.accumulate!` for JLArrays.
#
# Each method forwards to Base's `accumulate!` on the arrays returned by `typed_data`
# (presumably host-side views sharing storage with the JLArray -- confirm against the
# definition of `typed_data`) and then returns `output` itself, so that callers of
# `accumulate!(op, B, A)` get `B` back rather than the unwrapped array.

# Vector input, no `dims`, no `init`: accumulate along the only dimension.
function Base._accumulate!(op, output::AnyJLArray, input::AnyJLVector,
                           dims::Nothing, init::Nothing)
    accumulate!(op, typed_data(output), typed_data(input); dims=1)
    return output
end

# Explicit `dims`, no `init`.
function Base._accumulate!(op, output::AnyJLArray, input::AnyJLArray,
                           dims::Integer, init::Nothing)
    accumulate!(op, typed_data(output), typed_data(input); dims)
    return output
end

# Vector input, no `dims`, with an `init` wrapped in `Some`.
function Base._accumulate!(op, output::AnyJLArray, input::AnyJLVector,
                           dims::Nothing, init::Some)
    accumulate!(op, typed_data(output), typed_data(input); dims=1, init=something(init))
    return output
end

# Explicit `dims`, with an `init` wrapped in `Some`.
function Base._accumulate!(op, output::AnyJLArray, input::AnyJLArray,
                           dims::Integer, init::Some)
    accumulate!(op, typed_data(output), typed_data(input); dims, init=something(init))
    return output
end

# Pairwise accumulation over JLArray vectors is routed through the regular
# `accumulate!` entry point; `result` receives the accumulated values of `v`.
function Base.accumulate_pairwise!(op, result::AnyJLVector, v::AnyJLVector)
    return accumulate!(op, result, v)
end

# `accumulate` for JLArrays.
#
# The computation is carried out by Base on the arrays returned by `typed_data`, and
# the result is always handed back in an array created with `similar(A, ...)`, so the
# caller receives the same array family it passed in.
#
# Keywords:
# - `dims`: dimension to accumulate along; `nothing` (the default) treats `A` as a
#   flat sequence of elements.
# - `init`: optional initial value; when given, its type takes part in choosing the
#   element type of the result.
# Any other keyword raises an `ArgumentError`.
function Base.accumulate(op, A::AnyJLArray;
                         dims::Union{Nothing,Integer}=nothing, kw...)
    nt = values(kw)
    if dims === nothing && !(A isa AbstractVector)
        # This branch takes care of the cases not handled by `_accumulate!`:
        # accumulate over the flattened elements (Base validates `kw` here), then
        # copy the values into a freshly allocated array shaped like `A`.
        flat = accumulate(op, typed_data(A)[:]; kw...)
        out = similar(A, eltype(flat))
        copyto!(typed_data(out), flat)
        return out
    end
    if isempty(kw)
        out = similar(A, Base.promote_op(op, eltype(A), eltype(A)))
        # Without `init`, defer to Base's own seeding of the accumulation instead of
        # looking up a neutral element for `op`.
        accumulate!(op, typed_data(out), typed_data(A); dims)
    elseif keys(nt) === (:init,)
        out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A)))
        accumulate!(op, typed_data(out), typed_data(A); dims, init=nt.init)
    else
        throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))"))
    end
    return out
end


## KernelAbstractions interface

KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
Expand Down
2 changes: 2 additions & 0 deletions src/GPUArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Reexport
@reexport using GPUArraysCore

using KernelAbstractions
import AcceleratedKernels as AK

# device functionality
include("device/abstractarray.jl")
Expand All @@ -27,6 +28,7 @@ include("host/construction.jl")
include("host/base.jl")
include("host/indexing.jl")
include("host/broadcast.jl")
include("host/accumulate.jl")
include("host/mapreduce.jl")
include("host/linalg.jl")
include("host/math.jl")
Expand Down
35 changes: 35 additions & 0 deletions src/host/accumulate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
## Base interface

# Accumulation entry points reached from `Base.accumulate!` for GPU arrays. All four
# methods forward to `AK.accumulate!` on the backend of `output`; they differ only in
# where the initial value comes from.

# Vector input, no `dims`, no `init`: seed with the neutral element of `op`.
function Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector,
                           dims::Nothing, init::Nothing)
    backend = get_backend(output)
    neutral = AK.neutral_element(op, eltype(output))
    return AK.accumulate!(op, output, input, backend; dims, init=neutral)
end

# Explicit `dims`, no `init`: seed with the neutral element of `op`.
function Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray,
                           dims::Integer, init::Nothing)
    backend = get_backend(output)
    neutral = AK.neutral_element(op, eltype(output))
    return AK.accumulate!(op, output, input, backend; dims, init=neutral)
end

# Vector input, no `dims`, caller-supplied `init` wrapped in `Some`.
function Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector,
                           dims::Nothing, init::Some)
    backend = get_backend(output)
    return AK.accumulate!(op, output, input, backend; dims, init=something(init))
end

# Explicit `dims`, caller-supplied `init` wrapped in `Some`.
function Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray,
                           dims::Integer, init::Some)
    backend = get_backend(output)
    return AK.accumulate!(op, output, input, backend; dims, init=something(init))
end

# Pairwise accumulation over GPU vectors is routed through the regular
# `accumulate!` entry point; `result` receives the accumulated values of `v`.
function Base.accumulate_pairwise!(op, result::AnyGPUVector, v::AnyGPUVector)
    return accumulate!(op, result, v)
end

# `accumulate` for GPU arrays, implemented on top of AcceleratedKernels (`AK`).
#
# Keywords:
# - `dims`: dimension to accumulate along; `nothing` (the default) treats `A` as a
#   flat sequence of elements.
# - `init`: optional initial value. When omitted, the neutral element of `op`
#   reported by `AK.neutral_element` is used as the seed.
# Any other keyword raises an `ArgumentError`, whichever code path is taken.
function Base.accumulate(op, A::AnyGPUArray;
                         dims::Union{Nothing,Integer}=nothing, kw...)
    nt = values(kw)
    has_init = !isempty(kw)
    # Validate keywords before branching so that the flattened N-d path below does
    # not silently ignore unsupported ones.
    if has_init && keys(nt) !== (:init,)
        throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))"))
    end

    if dims === nothing && !(A isa AbstractVector)
        # This branch takes care of the cases not handled by `_accumulate!`:
        # accumulate over the flattened copy `A[:]` and restore the original shape.
        init = has_init ? nt.init : AK.neutral_element(op, eltype(A))
        return reshape(AK.accumulate(op, A[:], get_backend(A); init), size(A))
    end

    if has_init
        out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A)))
        init = nt.init
    else
        out = similar(A, Base.promote_op(op, eltype(A), eltype(A)))
        init = AK.neutral_element(op, eltype(out))
    end
    return AK.accumulate!(op, out, A, get_backend(A); dims, init)
end
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
# Relative to this Project.toml, so the in-repository JLArrays is used on any machine.
JLArrays = {path = "../lib/JLArrays"}
1 change: 1 addition & 0 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ include("testsuite/indexing.jl")
include("testsuite/base.jl")
include("testsuite/vector.jl")
include("testsuite/reductions.jl")
include("testsuite/accumulations.jl")
include("testsuite/broadcasting.jl")
include("testsuite/linalg.jl")
include("testsuite/math.jl")
Expand Down
108 changes: 108 additions & 0 deletions test/testsuite/accumulations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Test suite for `accumulate` / `accumulate!` on the array type `AT`, run once per
# element type in `eltypes`. `compare` is defined elsewhere in the test suite
# (presumably it checks the result on `AT` against the plain-`Array` result).
@testsuite "accumulations" (AT, eltypes)->begin
    @testset "$ET" for ET in eltypes
        # Real element types sample the values 1..10; other types sample `ET` itself.
        range = ET <: Real ? (ET(1):ET(10)) : ET

        # 1d arrays: every length from 1 to 256, then a few random lengths
        for len in 1:256
            @test compare(v -> accumulate(+, v; init=zero(ET)), AT, rand(range, len))
        end

        for len in rand(1:100, 10)
            @test compare(v -> accumulate(+, v; init=zero(ET)), AT, rand(range, len))
        end

        # nd arrays reduced as 1d (no `dims` given)
        for _ in 1:10
            sz = (rand(1:10), rand(1:10), rand(1:10))
            @test compare(v -> accumulate(+, v; init=zero(ET)), AT, rand(range, sz...))
        end

        # 1d arrays with a random init value
        for len in rand(1:100, 10)
            seed = rand(range)
            @test compare(v -> accumulate(+, v; init=seed), AT, rand(range, len))
        end

        # nd arrays: corner cases (tiny extents, `dims` up to 4 on 3d input)
        for d in 1:4, isize in 1:3, jsize in 1:3, ksize in 1:3
            @test compare(v -> accumulate(+, v; dims=d, init=zero(ET)), AT,
                          rand(range, isize, jsize, ksize))
        end

        # nd arrays: random extents along each dimension
        for _ in 1:10, d in 1:3
            sz = (rand(1:10), rand(1:10), rand(1:10))
            @test compare(v -> accumulate(+, v; dims=d, init=zero(ET)), AT,
                          rand(range, sz...))
        end

        # nd arrays: random extents with a random init value
        for _ in 1:10, d in 1:3
            sz = (rand(1:10), rand(1:10), rand(1:10))
            seed = rand(range)
            @test compare(v -> accumulate(+, v; init=seed, dims=d), AT, rand(range, sz...))
        end

        # Larger containers to try and detect weird bugs:
        # small, large, odd & even, pow2 and not
        for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1)
            # Skip large tests on small datatypes
            if n >= 10000 && sizeof(real(ET)) <= 2
                continue
            end

            @test compare(v -> accumulate(+, v), AT, rand(range, n))
            @test compare(v -> accumulate(+, v), AT, rand(range, n, 2))
            with_init = Base.Fix2((v, s) -> accumulate(+, v; init=s), rand(range))
            @test compare(with_init, AT, rand(range, n))
        end

        # in place: the destination must hold the accumulated values afterwards
        @test compare(AT, rand(range, 2)) do v
            accumulate!(+, v, copy(v))
            v
        end

        # unsupported keywords are rejected with the same message Base uses
        @test_throws ArgumentError("accumulate does not support the keyword arguments [:bad_kwarg]") accumulate(+, AT(rand(ET, 10)); bad_kwarg="bad")
    end
end

# Test suite for `cumsum` and `cumprod` on the array type `AT`, run once per element
# type in `eltypes` (plus one `Bool` input for `cumsum`). `compare` is defined
# elsewhere in the test suite (presumably it checks the result on `AT` against the
# plain-`Array` result).
@testsuite "accumulations/cumsum & cumprod" (AT, eltypes)->begin
    @test compare(cumsum, AT, rand(Bool, 16))

    @testset "$ET" for ET in eltypes
        # Real element types sample the values 1..10; other types sample `ET` itself.
        # Shared by the cumsum and cumprod cases below.
        range = ET <: Real ? (ET(1):ET(10)) : ET

        # cumsum on vectors of random length
        for len in rand(1:100, 10)
            @test compare(v -> cumsum(v; dims=1), AT, rand(range, len))
        end

        # cumsum along each dimension of random 3d arrays
        for _ in 1:10, d in 1:3
            sz = (rand(1:10), rand(1:10), rand(1:10))
            @test compare(v -> cumsum(v; dims=d), AT, rand(range, sz...))
        end

        # cumprod on a long vector of ones
        @test compare(v -> cumprod(v; dims=1), AT, ones(ET, 100_000))

        # cumprod along each dimension of random 3d arrays
        for _ in 1:10, d in 1:3
            sz = (rand(1:10), rand(1:10), rand(1:10))
            @test compare(v -> cumprod(v; dims=d), AT, rand(range, sz...))
        end
    end
end
Loading