Skip to content

Commit 7643c4a

Browse files
committed
handle type parameters in kernels
1 parent 1497d41 commit 7643c4a

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/macros.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@ function __kernel(expr)
77
body = expr.args[2]
88

99
# parse decl
10-
@assert isexpr(decl, :call)
10+
# `@kernel fname(::T) where {T}`
11+
if isexpr(decl, :where)
12+
iswhere = true
13+
whereargs = decl.args[2:end]
14+
decl = decl.args[1]
15+
else
16+
iswhere = false
17+
end
18+
@assert isexpr(decl, :call)
1119
name = decl.args[1]
1220

1321
# List of tuple (Symbol, Bool) where the bool
@@ -37,6 +45,11 @@ function __kernel(expr)
3745
gpu_decl = Expr(:call, gpu_name, arglist...)
3846
cpu_decl = Expr(:call, cpu_name, arglist...)
3947

48+
if iswhere
49+
gpu_decl = Expr(:where, gpu_decl, whereargs...)
50+
cpu_decl = Expr(:where, cpu_decl, whereargs...)
51+
end
52+
4053
# Without the deepcopy we might accidentally modify expr shared between CPU and GPU
4154
gpu_body = transform_gpu(deepcopy(body), args)
4255
gpu_function = Expr(:function, gpu_decl, gpu_body)

test/test.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,18 @@ end
167167
@test occursin("@llvm.nvvm.ldg", IR)
168168
end
169169
end
170-
end
170+
end
171+
172+
@kernel function kernel_val!(a, ::Val{m}) where {m}
173+
I = @index(Global)
174+
@inbounds a[I] = m
175+
end
176+
177+
A = zeros(Int64, 1024)
178+
wait(kernel_val!(CPU())(A,Val(3), ndrange=size(A)))
179+
@test all((a)->a==3, A)
180+
if has_cuda_gpu()
181+
A = CuArrays.zeros(Int64, 1024)
182+
wait(kernel_val!(CUDA())(A,Val(3), ndrange=size(A)))
183+
@test all((a)->a==3, A)
184+
end

0 commit comments

Comments
 (0)