Skip to content

Commit ba2bdbf

Browse files
bors[bot] and vchuravy authored
Merge #38
38: add ntuple index type r=vchuravy a=vchuravy bors r+ Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
2 parents 323f165 + 1143ad8 commit ba2bdbf

File tree

4 files changed

+17
-10
lines changed

4 files changed

+17
-10
lines changed

examples/matmul.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ end
66

77
# Simple kernel for matrix multiplication
88
@kernel function matmul_kernel!(a, b, c)
9-
cI = @index(Global, Cartesian)
9+
i, j = @index(Global, NTuple)
1010

1111
# creating a temporary sum variable for matrix multiplication
1212
tmp_sum = zero(eltype(c))
13-
for i = 1:size(a)[2]
14-
tmp_sum += a[cI[1],i] * b[i,cI[2]]
13+
for k = 1:size(a)[2]
14+
tmp_sum += a[i,k] * b[k, j]
1515
end
1616

17-
c[cI] = tmp_sum
17+
c[i,j] = tmp_sum
1818
end
1919

2020
# Creating a wrapper kernel for launching with error checks

examples/naive_transpose.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ if CUDAapi.has_cuda_gpu()
55
end
66

77
@kernel function naive_transpose_kernel!(a, b)
8-
I = @index(Global, Cartesian)
9-
i, j = Tuple(I)
8+
i, j = @index(Global, NTuple)
109
@inbounds b[i, j] = a[j, i]
1110
end
1211

examples/performance.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ using CUDAnative
99
using CUDAnative.NVTX
1010

1111
@kernel function transpose_kernel_naive!(b, a)
12-
I = @index(Global, Cartesian)
13-
i, j = I.I
12+
i, j = @index(Global, NTuple)
1413
@inbounds b[i, j] = a[j, i]
1514
end
1615

src/KernelAbstractions.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ A cartesian index is a general N-dimensional index that is derived from the iter
138138
# Index kind
139139
140140
- `Linear`: Produces an `Int64` that can be used to linearly index into memory.
141-
- `Global`: Produces a `CartesianIndex{N}` that can be used to index into memory.
141+
- `Cartesian`: Produces a `CartesianIndex{N}` that can be used to index into memory.
142+
- `NTuple`: Produces an `NTuple{N}` that can be used to index into memory.
142143
143144
If the index kind is not provided it defaults to `Linear`; this is subject to change.
144145
@@ -149,6 +150,7 @@ If the index kind is not provided it defaults to `Linear`, this is suspect to ch
149150
@index(Global, Cartesian)
150151
@index(Local, Cartesian)
151152
@index(Group, Linear)
153+
@index(Local, NTuple)
152154
@index(Global)
153155
```
154156
"""
@@ -158,7 +160,10 @@ macro index(locale, args...)
158160
end
159161

160162
if length(args) >= 1
161-
if args[1] === :Cartesian || args[1] === :Linear
163+
164+
if args[1] === :Cartesian ||
165+
args[1] === :Linear ||
166+
args[1] === :NTuple
162167
indexkind = args[1]
163168
args = args[2:end]
164169
else
@@ -184,6 +189,10 @@ function __index_Local_Cartesian end
184189
function __index_Group_Cartesian end
185190
function __index_Global_Cartesian end
186191

192+
# Tuple form of the local index: forwards every argument to
# `__index_Local_Cartesian` and converts its result with `Tuple`.
function __index_Local_NTuple(I...)
    cartesian = __index_Local_Cartesian(I...)
    return Tuple(cartesian)
end
193+
# Tuple form of the group index: forwards every argument to
# `__index_Group_Cartesian` and converts its result with `Tuple`.
function __index_Group_NTuple(I...)
    cartesian = __index_Group_Cartesian(I...)
    return Tuple(cartesian)
end
194+
# Tuple form of the global index: forwards every argument to
# `__index_Global_Cartesian` and converts its result with `Tuple`.
function __index_Global_NTuple(I...)
    cartesian = __index_Global_Cartesian(I...)
    return Tuple(cartesian)
end
195+
187196
struct ConstAdaptor end
188197

189198
Adapt.adapt_storage(to::ConstAdaptor, a::Array) = Base.Experimental.Const(a)

0 commit comments

Comments
 (0)