Skip to content

Commit 5769a8e

Browse files
authored
Allow return statements for GPU-only kernels (#538)
1 parent 419481c commit 5769a8e

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/macros.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false)
1414
def = splitdef(expr)
1515
name = def[:name]
1616
args = def[:args]
17-
find_return(expr) && error("Return statement not permitted in a kernel function $name")
17+
generate_cpu && find_return(expr) && error(
18+
"Return statement not permitted in a kernel function $name",
19+
)
1820

1921
constargs = Array{Bool}(undef, length(args))
2022
for (i, arg) in enumerate(args)

test/test.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,4 +306,22 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
306306
@test size(KernelAbstractions.zeros(backend, Float32, 0, 9)) == (0, 9)
307307
end
308308

309+
@kernel cpu = false function gpu_return_kernel!(x)
310+
i = @index(Global)
311+
if i ≤ (length(x) ÷ 2)
312+
x[i] = 1
313+
return
314+
end
315+
end
316+
@testset "GPU kernel return statement" begin
317+
if !(Backend() isa CPU)
318+
A = KernelAbstractions.zeros(Backend(), Int64, 1024)
319+
gpu_return_kernel!(Backend())(A; ndrange = length(A))
320+
synchronize(Backend())
321+
Ah = Array(A)
322+
@test all(a -> a == 1, @view(Ah[1:(length(A) ÷ 2)]))
323+
@test all(a -> a == 0, @view(Ah[(length(A) ÷ 2 + 1):end]))
324+
end
325+
end
326+
309327
end

0 commit comments

Comments
 (0)