Skip to content

Commit 28d96c1

Browse files
authored
Set codegen kwargs based on toplevel setting. (#600)
1 parent 7143d02 commit 28d96c1

File tree

3 files changed

+39
-49
lines changed

3 files changed

+39
-49
lines changed

src/driver.jl

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -52,40 +52,37 @@ export compile
5252
const compile_hook = Ref{Union{Nothing,Function}}(nothing)
5353

5454
"""
55-
compile(target::Symbol, job::CompilerJob;
56-
libraries=true, optimize=true, strip=false, ...)
55+
compile(target::Symbol, job::CompilerJob; kwargs...)
5756
5857
Compile the compiler job `job` to one of the
5958
following formats as specified by the `target` argument: `:julia` for Julia IR, `:llvm` for
6059
LLVM IR and `:asm` for machine code.
6160
6261
The following keyword arguments are supported:
63-
- `libraries`: link the GPU runtime and `libdevice` libraries (if required)
64-
- `optimize`: optimize the code (default: true)
65-
- `cleanup`: run cleanup passes on the code (default: true)
62+
- `toplevel`: indicates that this compilation is the outermost invocation of the compiler
63+
(default: true)
64+
- `libraries`: link the GPU runtime and `libdevice` libraries (default: true, if toplevel)
65+
- `optimize`: optimize the code (default: true, if toplevel)
66+
- `cleanup`: run cleanup passes on the code (default: true, if toplevel)
67+
- `validate`: enable optional validation of inputs and outputs (default: true, if toplevel)
6668
- `strip`: strip non-functional metadata and debug information (default: false)
67-
- `validate`: enable optional validation of input and outputs (default: true)
6869
- `only_entry`: only keep the entry function, remove all others (default: false).
6970
This option is only for internal use, to implement reflection's `dump_module`.
7071
7172
Other keyword arguments can be found in the documentation of [`cufunction`](@ref).
7273
"""
73-
function compile(target::Symbol, @nospecialize(job::CompilerJob);
74-
libraries::Bool=true, toplevel::Bool=true,
75-
optimize::Bool=true, cleanup::Bool=true, strip::Bool=false,
76-
validate::Bool=true, only_entry::Bool=false)
74+
function compile(target::Symbol, @nospecialize(job::CompilerJob); kwargs...)
7775
if compile_hook[] !== nothing
7876
compile_hook[](job)
7977
end
8078

81-
return codegen(target, job;
82-
libraries, toplevel, optimize, cleanup, strip, validate, only_entry)
79+
return codegen(target, job; kwargs...)
8380
end
8481

85-
function codegen(output::Symbol, @nospecialize(job::CompilerJob);
86-
libraries::Bool=true, toplevel::Bool=true, optimize::Bool=true,
87-
cleanup::Bool=true, strip::Bool=false, validate::Bool=true,
88-
only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob}=nothing)
82+
function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool=true,
83+
libraries::Bool=toplevel, optimize::Bool=toplevel, cleanup::Bool=toplevel,
84+
validate::Bool=toplevel, strip::Bool=false, only_entry::Bool=false,
85+
parent_job::Union{Nothing, CompilerJob}=nothing)
8986
if context(; throw_error=false) === nothing
9087
error("No active LLVM context. Use `JuliaContext()` do-block syntax to create one.")
9188
end
@@ -159,9 +156,9 @@ end
159156

160157
const __llvm_initialized = Ref(false)
161158

162-
@locked function emit_llvm(@nospecialize(job::CompilerJob);
163-
libraries::Bool=true, toplevel::Bool=true, optimize::Bool=true,
164-
cleanup::Bool=true, only_entry::Bool=false, validate::Bool=true)
159+
@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
160+
libraries::Bool, optimize::Bool, cleanup::Bool,
161+
validate::Bool, only_entry::Bool)
165162
if !__llvm_initialized[]
166163
InitializeAllTargets()
167164
InitializeAllTargetInfos()
@@ -186,8 +183,7 @@ const __llvm_initialized = Ref(false)
186183
entry = finish_module!(job, ir, entry)
187184

188185
# deferred code generation
189-
has_deferred_jobs = !only_entry && toplevel &&
190-
haskey(functions(ir), "deferred_codegen")
186+
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
191187
jobs = Dict{CompilerJob, String}(job => entry_fn)
192188
if has_deferred_jobs
193189
dyn_marker = functions(ir)["deferred_codegen"]
@@ -225,10 +221,8 @@ const __llvm_initialized = Ref(false)
225221
for dyn_job in keys(worklist)
226222
# cached compilation
227223
dyn_entry_fn = get!(jobs, dyn_job) do
228-
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; validate=false,
229-
optimize=false,
230-
toplevel=false,
231-
parent_job=job)
224+
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
225+
parent_job=job)
232226
dyn_entry_fn = LLVM.name(dyn_meta.entry)
233227
merge!(compiled, dyn_meta.compiled)
234228
@assert context(dyn_ir) == context(ir)
@@ -264,27 +258,24 @@ const __llvm_initialized = Ref(false)
264258
unsafe_delete!(ir, dyn_marker)
265259
end
266260

267-
if toplevel
268-
# always preload the runtime, and do so early; it cannot be part of any
269-
# timing block because it recurses into the compiler
270-
if !uses_julia_runtime(job) && libraries
261+
if libraries
262+
# load the runtime outside of a timing block (because it recurses into the compiler)
263+
if !uses_julia_runtime(job)
271264
runtime = load_runtime(job)
272265
runtime_fns = LLVM.name.(defs(runtime))
273266
runtime_intrinsics = ["julia.gc_alloc_obj"]
274267
end
275268

276269
@timeit_debug to "Library linking" begin
277-
if libraries
278-
# target-specific libraries
279-
undefined_fns = LLVM.name.(decls(ir))
280-
@timeit_debug to "target libraries" link_libraries!(job, ir, undefined_fns)
281-
282-
# GPU run-time library
283-
if !uses_julia_runtime(job) && any(fn -> fn in runtime_fns ||
284-
fn in runtime_intrinsics,
285-
undefined_fns)
286-
@timeit_debug to "runtime library" link_library!(ir, runtime)
287-
end
270+
# target-specific libraries
271+
undefined_fns = LLVM.name.(decls(ir))
272+
@timeit_debug to "target libraries" link_libraries!(job, ir, undefined_fns)
273+
274+
# GPU run-time library
275+
if !uses_julia_runtime(job) && any(fn -> fn in runtime_fns ||
276+
fn in runtime_intrinsics,
277+
undefined_fns)
278+
@timeit_debug to "runtime library" link_library!(ir, runtime)
288279
end
289280
end
290281
end
@@ -434,7 +425,7 @@ const __llvm_initialized = Ref(false)
434425
end
435426

436427
@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module;
437-
strip::Bool=false, validate::Bool=true, format::LLVM.API.LLVMCodeGenFileType)
428+
strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
438429
# NOTE: strip after validation to get better errors
439430
if strip
440431
@timeit_debug to "Debug info removal" strip_debuginfo!(ir)

src/precompile.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ function _precompile_()
5656
@assert precompile(Tuple{typeof(GPUCompiler.lower_gc_frame!),LLVM.Function})
5757
@assert precompile(Tuple{typeof(GPUCompiler.lower_throw!),LLVM.Module})
5858
#@assert precompile(Tuple{typeof(GPUCompiler.split_kwargs),Tuple{},Vector{Symbol},Vararg{Vector{Symbol}, N} where N})
59-
let fbody = try __lookup_kwbody__(which(GPUCompiler.compile, (Symbol,GPUCompiler.CompilerJob,))) catch missing end
60-
if !ismissing(fbody)
61-
@assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
62-
@assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
63-
end
64-
end
59+
# let fbody = try __lookup_kwbody__(which(GPUCompiler.compile, (Symbol,GPUCompiler.CompilerJob,))) catch missing end
60+
# if !ismissing(fbody)
61+
# @assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
62+
# @assert precompile(fbody, (Bool,Bool,Bool,Bool,Bool,Bool,Bool,typeof(GPUCompiler.compile),Symbol,GPUCompiler.CompilerJob,))
63+
# end
64+
# end
6565
end

src/rtlib.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ end
5656
function emit_function!(mod, config::CompilerConfig, f, method)
5757
tt = Base.to_tuple_type(method.types)
5858
source = generic_methodinstance(f, tt)
59-
new_mod, meta = codegen(:llvm, CompilerJob(source, config);
60-
optimize=false, libraries=false, validate=false)
59+
new_mod, meta = codegen(:llvm, CompilerJob(source, config); toplevel=false)
6160
ft = function_type(meta.entry)
6261
expected_ft = convert(LLVM.FunctionType, method)
6362
if return_type(ft) != return_type(expected_ft)

0 commit comments

Comments
 (0)