Skip to content

Commit 42cc70e

Browse files
leios and vchuravy authored
Unified printing (#61)
* adding random draft for casette overdubbing of print function. * CPU tests passing for @print macro * testing print statement on GPU devices. * removing Manifest from PR * Update src/KernelAbstractions.jl Co-Authored-By: Valentin Churavy <vchuravy@users.noreply.github.com> * Update src/macros.jl Co-Authored-By: Valentin Churavy <vchuravy@users.noreply.github.com> * Update src/macros.jl Co-Authored-By: Valentin Churavy <vchuravy@users.noreply.github.com> * specifying test to run on backend instead of boolean flag Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
1 parent c3b0428 commit 42cc70e

File tree

5 files changed

+109
-1
lines changed

5 files changed

+109
-1
lines changed

src/KernelAbstractions.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module KernelAbstractions
22

33
export @kernel
4-
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize
4+
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print
55
export Device, GPU, CPU, CUDA, Event
66
export async_copy!
77

@@ -28,6 +28,7 @@ and then invoked on the arguments.
2828
- [`@private`](@ref)
2929
- [`@uniform`](@ref)
3030
- [`@synchronize`](@ref)
31+
- [`@print`](@ref)
3132
3233
# Example:
3334
@@ -147,6 +148,48 @@ macro synchronize(cond)
147148
end
148149
end
149150

151+
"""
    @print(items...)

This is a unified print statement.

# Platform differences
- `GPU`: This will reorganize the items to print via @cuprintf
- `CPU`: This will call `print(items...)`
"""
macro print(items...)
    args = Union{Val,Expr,Symbol}[]

    worklist = collect(items)
    while !isempty(worklist)
        item = popfirst!(worklist)

        # Flatten string-interpolation nodes into their component parts
        # and reprocess each part.
        if item isa Expr && item.head === :string
            prepend!(worklist, item.args)
            continue
        end

        # Lift literals into the type domain with `Val` so the backend
        # printer can see them at compile time.
        if isbits(item)                 # literal numbers, etc.
            push!(args, Val(item))
        elseif item isa QuoteNode       # literal symbols
            push!(args, Val(item.value))
        elseif item isa String          # literal strings need to be interned
            push!(args, Val(Symbol(item)))
        else                            # runtime values forwarded to the printer
            push!(args, item)
        end
    end

    return quote
        $__print($(map(esc, args)...))
    end
end
192+
150193
"""
151194
@index
152195
@@ -345,6 +388,10 @@ function __synchronize()
345388
error("@synchronize used outside kernel or not captured")
346389
end
347390

391+
"""
    __print(items...)

Fallback for the `@print` machinery. It is only reached when `@print`
is used outside of a kernel, i.e. when no backend context has
overdubbed the call, and therefore always errors.
"""
__print(items...) = error("@print used outside of kernel")
394+
348395
###
349396
# Backends/Implementation
350397
###

src/backends/cpu.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,30 @@ end
155155
end
156156
end
157157

158+
"""
    _print(items...)

CPU implementation behind `@print`. Literal arguments arrive lifted into
`Val` types by the `@print` macro; this generated function unwraps them
back into compile-time constants and splices everything into a single
`print` call.
"""
@generated function _print(items...)
    # Build one argument expression per input: `Val{x}` parameters are
    # unwrapped to their literal value, runtime values are indexed out
    # of the `items` tuple.
    args = Union{Expr,QuoteNode}[]
    for i in 1:length(items)
        T = items[i]
        if T <: Val
            push!(args, QuoteNode(T.parameters[1]))
        else
            push!(args, :(items[$i]))
        end
    end

    return quote
        print($(args...))
    end
end
176+
177+
178+
# Route `__print` calls made inside CPU kernels to the CPU-side
# generated printer `_print`.
@inline Cassette.overdub(ctx::CPUCtx, ::typeof(__print), items...) =
    _print(items...)
181+
158182
generate_overdubs(CPUCtx)
159183

160184
###

src/backends/cuda.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ end
274274
CUDAnative.sync_threads()
275275
end
276276

277+
# Route `__print` calls made inside CUDA kernels to CUDAnative's
# device-side printf machinery.
@inline Cassette.overdub(ctx::CUDACtx, ::typeof(__print), args...) =
    CUDAnative._cuprint(args...)
280+
277281
###
278282
# GPU implementation of `@Const`
279283
###

test/print_test.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using KernelAbstractions, Test
using CUDAapi
if CUDAapi.has_cuda_gpu()
    using CuArrays
    CuArrays.allowscalar(false)
end

# Minimal kernel exercising `@print` with a mix of literal strings and
# a runtime value (the global index).
@kernel function kernel_print()
    I = @index(Global)
    @print("Hello from thread ", I, "!\n")
end

# Instantiate the print kernel on `backend` with a workgroup size of 4
# and launch it over 4 work-items, returning the launch event.
# Dispatching on `backend` directly (instead of branching on CPU/CUDA)
# keeps this working for any device type.
function test_print(backend)
    kernel = kernel_print(backend, 4)
    kernel(ndrange=(4,))
end

@testset "print test" begin
    if CUDAapi.has_cuda_gpu()
        wait(test_print(CUDA()))
        @test true
    end

    wait(test_print(CPU()))
    @test true
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,6 @@ end
2121
include("nditeration.jl")
2222
end
2323

24+
include("print_test.jl")
25+
2426
include("examples.jl")

0 commit comments

Comments
 (0)