Skip to content

Commit 8cbc86e

Browse files
authored
Merge pull request #64 from JuliaGPU/vc/cleanup
make at_print work outside KA
2 parents bc0a63e + f375223 commit 8cbc86e

File tree

3 files changed

+21
-28
lines changed

3 files changed

+21
-28
lines changed

src/KernelAbstractions.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,22 @@ function __synchronize()
409409
error("@synchronize used outside kernel or not captured")
410410
end
411411

412-
function __print(items...)
413-
error("@print used outside of kernel")
412+
@generated function __print(items...)
413+
str = ""
414+
args = []
415+
416+
for i in 1:length(items)
417+
item = :(items[$i])
418+
T = items[i]
419+
if T <: Val
420+
item = QuoteNode(T.parameters[1])
421+
end
422+
push!(args, item)
423+
end
424+
425+
quote
426+
print($(args...))
427+
end
414428
end
415429

416430
###

src/backends/cpu.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -153,28 +153,9 @@ end
153153
end
154154
end
155155

156-
@generated function _print(items...)
157-
str = ""
158-
args = []
159-
160-
for i in 1:length(items)
161-
item = :(items[$i])
162-
T = items[i]
163-
if T <: Val
164-
item = QuoteNode(T.parameters[1])
165-
end
166-
push!(args, item)
167-
end
168-
169-
quote
170-
print($(args...))
171-
end
172-
173-
end
174-
175156

176157
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__print), items...)
177-
_print(items...)
158+
__print(items...)
178159
end
179160

180161
generate_overdubs(CPUCtx)

test/print_test.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,10 @@ end
1111
end
1212

1313
function test_print(backend)
14-
if backend == CPU()
15-
kernel = kernel_print(CPU(), 4)
16-
else
17-
kernel = kernel_print(CUDA(), 4)
18-
end
14+
kernel = kernel_print(backend, 4)
1915
kernel(ndrange=(4,))
2016
end
2117

22-
2318
@testset "print test" begin
2419
if CUDAapi.has_cuda_gpu()
2520
wait(test_print(CUDA()))
@@ -28,4 +23,7 @@ end
2823

2924
wait(test_print(CPU()))
3025
@test true
26+
27+
@print("Why this should work")
28+
@test true
3129
end

0 commit comments

Comments
 (0)