Skip to content

Commit b5709d7

Browse files
authored
Merge pull request #25 from JuliaGPU/vc/val
Handle type parameters in kernel functions
2 parents 1497d41 + d54367c commit b5709d7

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

src/backends/cpu.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ end
5959
Cassette.@context CPUCtx
6060

6161
function mkcontext(kernel::Kernel{CPU}, I, _ndrange, _workgroupsize)
62-
metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), false}(I, _ndrange, workgroupsize)
62+
metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), false}(I, _ndrange, _workgroupsize)
6363
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
6464
end
6565

6666
function mkcontextdynamic(kernel::Kernel{CPU}, I, _ndrange, _workgroupsize)
67-
metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), true}(I, _ndrange, workgroupsize)
67+
metadata = CompilerMetadata{workgroupsize(kernel), ndrange(kernel), true}(I, _ndrange, _workgroupsize)
6868
Cassette.disablehooks(CPUCtx(pass = CompilerPass, metadata=metadata))
6969
end
7070

src/macros.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@ function __kernel(expr)
77
body = expr.args[2]
88

99
# parse decl
10-
@assert isexpr(decl, :call)
10+
# `@kernel fname(::T) where {T}`
11+
if isexpr(decl, :where)
12+
iswhere = true
13+
whereargs = decl.args[2:end]
14+
decl = decl.args[1]
15+
else
16+
iswhere = false
17+
end
18+
@assert isexpr(decl, :call)
1119
name = decl.args[1]
1220

1321
# List of tuple (Symbol, Bool) where the bool
@@ -37,6 +45,11 @@ function __kernel(expr)
3745
gpu_decl = Expr(:call, gpu_name, arglist...)
3846
cpu_decl = Expr(:call, cpu_name, arglist...)
3947

48+
if iswhere
49+
gpu_decl = Expr(:where, gpu_decl, whereargs...)
50+
cpu_decl = Expr(:where, cpu_decl, whereargs...)
51+
end
52+
4053
# Without the deepcopy we might accidentially modify expr shared between CPU and GPU
4154
gpu_body = transform_gpu(deepcopy(body), args)
4255
gpu_function = Expr(:function, gpu_decl, gpu_body)

test/test.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,18 @@ end
167167
@test occursin("@llvm.nvvm.ldg", IR)
168168
end
169169
end
170-
end
170+
end
171+
172+
@kernel function kernel_val!(a, ::Val{m}) where {m}
173+
I = @index(Global)
174+
@inbounds a[I] = m
175+
end
176+
177+
A = zeros(Int64, 1024)
178+
wait(kernel_val!(CPU())(A,Val(3), ndrange=size(A)))
179+
@test all((a)->a==3, A)
180+
if has_cuda_gpu()
181+
A = CuArrays.zeros(Int64, 1024)
182+
wait(kernel_val!(CUDA())(A,Val(3), ndrange=size(A)))
183+
@test all((a)->a==3, A)
184+
end

0 commit comments

Comments
 (0)