Skip to content

Commit 15494f3

Browse files
authored
Validate inputs and outputs earlier. (#390)
This makes sure we catch invalid arguments before, e.g., processing them in the Metal back-end.
1 parent 3ad8bcc commit 15494f3

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

src/driver.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ The following keyword arguments are supported:
2525
- `optimize`: optimize the code (default: true)
2626
- `cleanup`: run cleanup passes on the code (default: true)
2727
- `strip`: strip non-functional metadata and debug information (default: false)
28-
- `validate`: validate the generated IR before emitting machine code (default: true)
28+
- `validate`: enable optional validation of input and outputs (default: true)
2929
- `only_entry`: only keep the entry function, remove all others (default: false).
3030
This option is only for internal use, to implement reflection's `dump_module`.
3131
@@ -90,7 +90,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
9090
ctx::Union{JuliaContextType,Nothing}=nothing)
9191
## Julia IR
9292

93-
mi, mi_meta = emit_julia(job)
93+
mi, mi_meta = emit_julia(job; validate)
9494

9595
if output == :julia
9696
return mi, mi_meta
@@ -112,7 +112,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
112112
Use a JuliaContext instead.""")
113113
end
114114

115-
ir, ir_meta = emit_llvm(job, mi; libraries, deferred_codegen, optimize, cleanup, only_entry, ctx)
115+
ir, ir_meta = emit_llvm(job, mi; libraries, deferred_codegen, optimize, cleanup, only_entry, validate, ctx)
116116

117117
if output == :llvm
118118
if strip
@@ -148,8 +148,11 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
148148
end
149149
end
150150

151-
@locked function emit_julia(@nospecialize(job::CompilerJob))
152-
@timeit_debug to "validation" check_method(job)
151+
@locked function emit_julia(@nospecialize(job::CompilerJob); validate::Bool=true)
152+
@timeit_debug to "Validation" begin
153+
check_method(job) # not optional
154+
validate && check_invocation(job)
155+
end
153156

154157
@timeit_debug to "Julia front-end" begin
155158

@@ -201,7 +204,8 @@ const __llvm_initialized = Ref(false)
201204

202205
@locked function emit_llvm(@nospecialize(job::CompilerJob), @nospecialize(method_instance);
203206
libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true,
204-
cleanup::Bool=true, only_entry::Bool=false, ctx::JuliaContextType)
207+
cleanup::Bool=true, only_entry::Bool=false, validate::Bool=true,
208+
ctx::JuliaContextType)
205209
if !__llvm_initialized[]
206210
InitializeAllTargets()
207211
InitializeAllTargetInfos()
@@ -293,8 +297,10 @@ const __llvm_initialized = Ref(false)
293297
for dyn_job in keys(worklist)
294298
# cached compilation
295299
dyn_entry_fn = get!(deferred_jobs, dyn_job) do
296-
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; optimize=false,
297-
deferred_codegen=false, parent_job=job, ctx)
300+
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; validate=false,
301+
optimize=false,
302+
deferred_codegen=false,
303+
parent_job=job, ctx)
298304
dyn_entry_fn = LLVM.name(dyn_meta.entry)
299305
merge!(compiled, dyn_meta.compiled)
300306
@assert context(dyn_ir) == unwrap_context(ctx)
@@ -407,18 +413,17 @@ const __llvm_initialized = Ref(false)
407413
end
408414
end
409415

410-
return ir, (; entry, compiled)
411-
end
412-
413-
@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module;
414-
strip::Bool=false, validate::Bool=true, format::LLVM.API.LLVMCodeGenFileType)
415416
if validate
416417
@timeit_debug to "Validation" begin
417-
check_invocation(job)
418418
check_ir(job, ir)
419419
end
420420
end
421421

422+
return ir, (; entry, compiled)
423+
end
424+
425+
@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module;
426+
strip::Bool=false, validate::Bool=true, format::LLVM.API.LLVMCodeGenFileType)
422427
# NOTE: strip after validation to get better errors
423428
if strip
424429
@timeit_debug to "Debug info removal" strip_debuginfo!(ir)

src/rtlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666
function emit_function!(mod, @nospecialize(job::CompilerJob), f, method; ctx::JuliaContextType)
6767
tt = Base.to_tuple_type(method.types)
6868
new_mod, meta = codegen(:llvm, similar(job, FunctionSpec(f, tt, #=kernel=# false));
69-
optimize=false, libraries=false, ctx)
69+
optimize=false, libraries=false, validate=false, ctx)
7070
ft = eltype(llvmtype(meta.entry))
7171
expected_ft = convert(LLVM.FunctionType, method; ctx=context(new_mod))
7272
if LLVM.return_type(ft) != LLVM.return_type(expected_ft)

test/native.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ end
8888
job, _ = native_job(foo, (Float64,))
8989
JuliaContext() do ctx
9090
# shouldn't segfault
91-
ir, meta = GPUCompiler.compile(:llvm, job; ctx)
91+
ir, meta = GPUCompiler.compile(:llvm, job; ctx, validate=false)
9292

9393
meth = only(methods(foo, (Float64,)))
9494

0 commit comments

Comments
 (0)