Skip to content

Commit 1c6a87c

Browse files
authored
Use KernelAbstractions.jl for gather/scatter kernels (#487)
* Use KA for gather
* Finish scatter
* Update testsuite for gather/scatter
* Cleanup
* Update tests
* Retain NNlibCUDA scatter kernels
* Fixup
* Add at-inbounds
* Add compat
* Use KA unsafe free
1 parent ee909e6 commit 1c6a87c

File tree

12 files changed

+536
-469
lines changed

12 files changed

+536
-469
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,9 @@ version = "0.8.19"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7+
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
78
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
810
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
911
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1012
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -21,8 +23,10 @@ NNlibAMDGPUExt = "AMDGPU"
2123
[compat]
2224
AMDGPU = "0.4.8"
2325
Adapt = "2, 3.2"
26+
Atomix = "0.1"
2427
ChainRulesCore = "1.13"
25-
KernelAbstractions = "0.9"
28+
GPUArraysCore = "0.1"
29+
KernelAbstractions = "0.9.2"
2630
Requires = "0.5, 1.0"
2731
julia = "1.6"
2832

ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@ include("batchedmul.jl")
1313
include("ctc.jl")
1414
include("fold.jl")
1515
include("scatter.jl")
16-
include("gather.jl")
1716
include("utils.jl")
1817
include("cudnn/cudnn.jl")
1918
include("cudnn/conv.jl")

ext/NNlibCUDA/src/gather.jl

Lines changed: 0 additions & 65 deletions
This file was deleted.

ext/NNlibCUDA/src/scatter.jl

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size
3030
return nothing
3131
end
3232

33-
function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
33+
function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
3434
max_idx, max_dims_idx, dims_size) where OP
3535
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
3636

@@ -73,7 +73,7 @@ end
7373

7474
## Gradients
7575

76-
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
76+
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
7777
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
7878
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
7979

@@ -93,7 +93,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
9393
return nothing
9494
end
9595

96-
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
96+
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
9797
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
9898
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
9999

@@ -113,7 +113,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
113113
return nothing
114114
end
115115

116-
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
116+
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
117117
rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
118118
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
119119

@@ -160,13 +160,13 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
160160
end
161161

162162
function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
163-
src::AnyCuArray{Tsrc,Nsrc},
163+
src::AnyCuArray{Tsrc,Nsrc},
164164
idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
165165
dims = Nsrc - Nidx
166166
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
167167
rev_idx = NNlib.reverse_indices(idx)
168168
rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))
169-
169+
170170
if dims == 0
171171
max_idx = length(idx)
172172
args = op, Δsrc, src, idx, rev_idx, max_idx, Tsrc

ext/NNlibCUDA/test/gather.jl

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
1-
@testset "gather" begin
1+
@testset "gather" begin
22
T = Float32
33
CT = CuArray{Float32}
4-
4+
55
## 1d src, 2d index of ints -> 2d output
66
src = CT([3, 4, 5, 6, 7])
77
index = cu([1 2 3 4;
@@ -10,14 +10,14 @@
1010
output = CT([3 4 5 6;
1111
6 4 3 5;
1212
5 7 7 5])
13-
13+
1414
y = NNlib.gather(src, index)
1515
@test y isa CuArray{Float32,2}
1616
@test size(y) == size(index)
1717
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
1818
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
1919
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
20-
20+
2121
## 1d src, 2d index of tuples -> 2d output
2222
src = CT([3, 4, 5, 6, 7])
2323
index = cu([(1,) (2,) (3,) (4,);
@@ -26,14 +26,14 @@
2626
output = CT([3 4 5 6;
2727
6 4 3 5;
2828
5 7 7 5])
29-
29+
3030
y = NNlib.gather(src, index)
3131
@test y isa CuArray{Float32,2}
3232
@test size(y) == size(index)
3333
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
3434
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
3535
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
36-
36+
3737
## 1d src, 2d index of CartesianIndex -> 2d output
3838
src = CT([3, 4, 5, 6, 7])
3939
index = cu(CartesianIndex.([(1,) (2,) (3,) (4,);
@@ -42,7 +42,7 @@
4242
output = CT([3 4 5 6;
4343
6 4 3 5;
4444
5 7 7 5])
45-
45+
4646
y = NNlib.gather(src, index)
4747
@test y isa CuArray{Float32,2}
4848
@test size(y) == size(index)
@@ -66,7 +66,7 @@
6666

6767

6868
## 2d src, 2d index of ints -> 3d output
69-
src = CT([3 5 7
69+
src = CT([3 5 7
7070
4 6 8])
7171
index = cu([1 2 3;
7272
2 2 1;
@@ -79,14 +79,14 @@
7979

8080
output[:,:,2] = [5 5 3
8181
6 6 4]
82-
82+
8383
output[:,:,3] = [7 3 7
8484
8 4 8]
85-
85+
8686
y = NNlib.gather(src, index)
8787
M = NNlib.typelength(eltype(index))
8888
Nsrc = ndims(src)
8989
@test y isa CuArray{Float32,3}
90-
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
90+
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
9191
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
9292
end

src/NNlib.jl

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
11
module NNlib
22

3+
import Atomix
34
import ChainRulesCore: rrule
45

56
using Base.Broadcast: broadcasted
67
using Base.Threads
78
using ChainRulesCore
9+
using GPUArraysCore
810
using KernelAbstractions
911
using KernelAbstractions: @atomic
1012
using LinearAlgebra

0 commit comments

Comments
 (0)