Skip to content

Commit b662c70

Browse files
committed
Detect device-side exceptions on the host.
1 parent 7344ad8 commit b662c70

File tree

5 files changed

+126
-2
lines changed

5 files changed

+126
-2
lines changed

src/OpenCL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ include("array.jl")
3131

3232
# compiler implementation
3333
include("compiler/compilation.jl")
34+
include("compiler/exceptions.jl")
3435
include("compiler/execution.jl")
3536
include("compiler/reflection.jl")
3637

src/compiler/compilation.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,OpenCLCompilerParams}) = OpenCL
99
GPUCompiler.method_table_view(job::OpenCLCompilerJob) =
1010
GPUCompiler.StackedMethodTable(job.world, method_table, SPIRVIntrinsics.method_table)
1111

12+
GPUCompiler.kernel_state_type(job::OpenCLCompilerJob) = KernelState
13+
1214
# filter out OpenCL built-ins
1315
# TODO: eagerly lower these using the translator API
1416
GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =

src/compiler/exceptions.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# support for device-side exceptions
2+
3+
## exception type
4+
5+
struct KernelException <: Exception
6+
dev::cl.Device
7+
end
8+
9+
function Base.showerror(io::IO, err::KernelException)
10+
print(io, "KernelException: exception thrown during kernel execution on device $(err.dev.name)")
11+
end
12+
13+
14+
## exception handling
15+
16+
const exception_infos = Dict{cl.Context, Union{Nothing, cl.AbstractPointerMemory}}()
17+
18+
# create a CPU/GPU exception flag for error signalling
19+
function create_exceptions!(ctx::cl.Context, dev::cl.Device)
20+
mem = get!(exception_infos, ctx) do
21+
if cl.svm_capabilities(cl.device()).fine_grain_buffer
22+
cl.svm_alloc(sizeof(ExceptionInfo_st); fine_grained=true)
23+
elseif cl.usm_supported(dev) && cl.usm_capabilities(dev).shared.access
24+
cl.shared_alloc(sizeof(ExceptionInfo_st); placement=:host)
25+
else
26+
nothing
27+
end
28+
end
29+
if mem === nothing
30+
return convert(ExceptionInfo, C_NULL)
31+
end
32+
33+
exception_info = convert(ExceptionInfo, mem)
34+
unsafe_store!(exception_info, ExceptionInfo_st())
35+
return exception_info
36+
end
37+
38+
# check the exception flags on every API call
39+
function check_exceptions()
40+
for (ctx, mem) in exception_infos
41+
mem === nothing && continue
42+
exception_info = convert(ExceptionInfo, mem)
43+
if exception_info.status != 0
44+
# restore the structure
45+
unsafe_store!(exception_info, ExceptionInfo_st())
46+
47+
# throw host-side
48+
dev = cl.device(ctx)
49+
throw(KernelException(dev))
50+
end
51+
end
52+
return
53+
end

src/compiler/execution.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ abstract type AbstractKernel{F, TT} end
143143
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
144144
call_args = Union{Expr,Symbol}[x[1] for x in zip(args, to_pass) if x[2]]
145145

146+
# add the kernel state as the first argument
147+
pushfirst!(call_t, KernelState)
148+
pushfirst!(call_args, :(kernel.state))
149+
146150
# replace non-isbits arguments (they should be unused, or compilation would have failed)
147151
for (i,dt) in enumerate(call_t)
148152
if !isbitstype(dt)
@@ -156,7 +160,20 @@ abstract type AbstractKernel{F, TT} end
156160

157161
quote
158162
indirect_memory = cl.AbstractMemory[]
163+
164+
# add exception info buffer to indirect memory
165+
# XXX: this is too expensive
166+
if kernel.state.exception_info != C_NULL
167+
ctx = cl.context()
168+
if haskey(exception_infos, ctx)
169+
push!(indirect_memory, exception_infos[ctx])
170+
end
171+
end
172+
159173
clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, call_kwargs...)
174+
175+
# check for exceptions after kernel execution
176+
check_exceptions()
160177
end
161178
end
162179

@@ -167,6 +184,7 @@ end
167184
struct HostKernel{F,TT} <: AbstractKernel{F,TT}
168185
f::F
169186
fun::cl.Kernel
187+
state::KernelState
170188
end
171189

172190

@@ -191,7 +209,9 @@ function clfunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
191209
kernel = get(_kernel_instances, h, nothing)
192210
if kernel === nothing
193211
# create the kernel state object
194-
kernel = HostKernel{F,tt}(f, fun)
212+
exception_info = create_exceptions!(ctx, dev)
213+
state = KernelState(exception_info)
214+
kernel = HostKernel{F,tt}(f, fun, state)
195215
_kernel_instances[h] = kernel
196216
end
197217
return kernel::HostKernel{F,tt}

src/device/runtime.jl

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,55 @@
11
# reset the runtime cache from global scope, so that any change triggers recompilation
22
GPUCompiler.reset_runtime()
33

4-
signal_exception() = return
4+
## exception handling
5+
6+
struct ExceptionInfo_st
7+
# whether an exception has been encountered (0 -> 1)
8+
status::Int32
9+
10+
ExceptionInfo_st() = new(0)
11+
end
12+
13+
# to simplify use of this struct, which is passed by-reference, use property overloading
14+
const ExceptionInfo = Ptr{ExceptionInfo_st}
15+
@inline function Base.getproperty(info::ExceptionInfo, sym::Symbol)
16+
if sym === :status
17+
unsafe_load(convert(Ptr{Int32}, info))
18+
else
19+
getfield(info, sym)
20+
end
21+
end
22+
@inline function Base.setproperty!(info::ExceptionInfo, sym::Symbol, value)
23+
if sym === :status
24+
unsafe_store!(convert(Ptr{Int32}, info), value)
25+
else
26+
setfield!(info, sym, value)
27+
end
28+
end
29+
30+
## kernel state
31+
32+
struct KernelState
33+
exception_info::ExceptionInfo
34+
35+
# XXX: Intel's SPIR-V compiler does not support array-valued kernel arguments, and Julia
36+
# emits homogeneous structs as arrays. Work around this by including a dummy field.
37+
dummy::UInt32
38+
end
39+
KernelState(exception_info::ExceptionInfo) = KernelState(exception_info, 42)
40+
41+
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)
42+
43+
function signal_exception()
44+
info = kernel_state().exception_info
45+
46+
# inform the host
47+
if info != C_NULL
48+
info.status = 1
49+
end
50+
51+
return
52+
end
553

654
malloc(sz) = C_NULL
755

0 commit comments

Comments
 (0)