Skip to content

Commit d85578f

Browse files
committed
add uniform
1 parent f65a912 commit d85578f

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

src/KernelAbstractions.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module KernelAbstractions
22

33
export @kernel
4-
export @Const, @localmem, @private, @synchronize, @index, groupsize
4+
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize
55
export Device, GPU, CPU, CUDA
66

77
using MacroTools
@@ -24,6 +24,7 @@ and then invoked on the arguments.
2424
- [`@index`](@ref)
2525
- [`@localmem`](@ref)
2626
- [`@private`](@ref)
27+
- [`@uniform`](@ref)
2728
- [`@synchronize`](@ref)
2829
2930
# Example:
@@ -69,6 +70,7 @@ function async_copy! end
6970
# Kernel language
7071
# - @localmem
7172
# - @private
73+
# - @uniform
7274
# - @synchronize
7375
# - @index
7476
# - groupsize
@@ -84,7 +86,7 @@ the total size you can use `prod(groupsize())`.
8486
function groupsize end
8587

8688
"""
87-
@localmem T dims
89+
@localmem T dims
8890
"""
8991
macro localmem(T, dims)
9092
# Stay in sync with CUDAnative
@@ -96,14 +98,21 @@ macro localmem(T, dims)
9698
end
9799

98100
"""
99-
@private T dims
101+
@private T dims
100102
"""
101103
macro private(T, dims)
102104
quote
103105
$Scratchpad($(esc(T)), Val($(esc(dims))))
104106
end
105107
end
106108

109+
"""
110+
@uniform value
111+
"""
112+
macro uniform(value)
113+
esc(value)
114+
end
115+
107116
"""
108117
@synchronize()
109118
"""

src/macros.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ function split(stmts)
131131
push!(indicies, stmt)
132132
continue
133133
elseif callee === Symbol("@localmem") ||
134-
callee === Symbol("@private")
134+
callee === Symbol("@private") ||
135+
callee === Symbol("@uniform")
135136
push!(allocations, stmt)
136137
continue
137138
end

test/localmem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ if has_cuda_gpu()
77
end
88

99
@kernel function localmem(A)
10+
N = @uniform prod(groupsize())
1011
I = @index(Global, Linear)
1112
i = @index(Local, Linear)
12-
lmem = @localmem Int groupsize() # Ok iff groupsize is static
13+
lmem = @localmem Int (N,) # Ok iff groupsize is static
1314
lmem[i] = i
1415
@synchronize
15-
A[I] = lmem[prod(groupsize()) - i + 1]
16+
A[I] = lmem[N - i + 1]
1617
end
1718

1819
function harness(backend, ArrayT)

0 commit comments

Comments
 (0)