Skip to content

Commit 584811f

Browse files
authored
Move arguments to compile/codegen into the CompilerConfig struct (#668)
1 parent 2f1db84 commit 584811f

File tree

17 files changed

+199
-126
lines changed

17 files changed

+199
-126
lines changed

src/driver.jl

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -43,64 +43,59 @@ end
4343

4444
export compile
4545

46-
# NOTE: the keyword arguments to compile/codegen control those aspects of compilation that
47-
# might have to be changed (e.g. set libraries=false when recursing, or set
48-
# strip=true for reflection). What remains defines the compilation job itself,
49-
# and those values are contained in the CompilerJob struct.
50-
5146
# (::CompilerJob)
5247
const compile_hook = Ref{Union{Nothing,Function}}(nothing)
5348

5449
"""
55-
compile(target::Symbol, job::CompilerJob; kwargs...)
56-
57-
Compile a function `f` invoked with types `tt` for device capability `cap` to one of the
58-
following formats as specified by the `target` argument: `:julia` for Julia IR, `:llvm` for
59-
LLVM IR and `:asm` for machine code.
60-
61-
The following keyword arguments are supported:
62-
- `toplevel`: indicates that this compilation is the outermost invocation of the compiler
63-
(default: true)
64-
- `libraries`: link the GPU runtime and `libdevice` libraries (default: true, if toplevel)
65-
- `optimize`: optimize the code (default: true, if toplevel)
66-
- `cleanup`: run cleanup passes on the code (default: true, if toplevel)
67-
- `validate`: enable optional validation of input and outputs (default: true, if toplevel)
68-
- `strip`: strip non-functional metadata and debug information (default: false)
69-
- `only_entry`: only keep the entry function, remove all others (default: false).
70-
This option is only for internal use, to implement reflection's `dump_module`.
71-
72-
Other keyword arguments can be found in the documentation of [`cufunction`](@ref).
50+
compile(target::Symbol, job::CompilerJob)
51+
52+
Compile a `job` to one of the following formats as specified by the `target` argument:
53+
`:julia` for Julia IR, `:llvm` for LLVM IR and `:asm` for machine code.
7354
"""
7455
function compile(target::Symbol, @nospecialize(job::CompilerJob); kwargs...)
56+
# XXX: remove on next major version
57+
if !isempty(kwargs)
58+
Base.depwarn("The GPUCompiler `compile` API does not take keyword arguments anymore. Use CompilerConfig instead.", :compile)
59+
config = CompilerConfig(job.config; kwargs...)
60+
job = CompilerJob(job.source, config)
61+
end
62+
7563
if compile_hook[] !== nothing
7664
compile_hook[](job)
7765
end
7866

79-
return codegen(target, job; kwargs...)
67+
return compile_unhooked(target, job)
8068
end
8169

82-
function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool=true,
83-
libraries::Bool=toplevel, optimize::Bool=toplevel, cleanup::Bool=toplevel,
84-
validate::Bool=toplevel, strip::Bool=false, only_entry::Bool=false,
85-
parent_job::Union{Nothing, CompilerJob}=nothing)
70+
# XXX: remove on next major version
71+
function codegen(output::Symbol, @nospecialize(job::CompilerJob); kwargs...)
72+
if !isempty(kwargs)
73+
Base.depwarn("The GPUCompiler `codegen` function is an internal API. Use `GPUCompiler.compile` (with any kwargs passed to `CompilerConfig`) instead.", :codegen)
74+
config = CompilerConfig(job.config; kwargs...)
75+
job = CompilerJob(job.source, config)
76+
end
77+
compile_unhooked(output, job)
78+
end
79+
80+
function compile_unhooked(output::Symbol, @nospecialize(job::CompilerJob); kwargs...)
8681
if context(; throw_error=false) === nothing
8782
error("No active LLVM context. Use `JuliaContext()` do-block syntax to create one.")
8883
end
8984

9085
@timeit_debug to "Validation" begin
9186
check_method(job) # not optional
92-
validate && check_invocation(job)
87+
job.config.validate && check_invocation(job)
9388
end
9489

9590
prepare_job!(job)
9691

9792

9893
## LLVM IR
9994

100-
ir, ir_meta = emit_llvm(job; libraries, toplevel, optimize, cleanup, only_entry, validate)
95+
ir, ir_meta = emit_llvm(job)
10196

10297
if output == :llvm
103-
if strip
98+
if job.config.strip
10499
@timeit_debug to "strip debug info" strip_debuginfo!(ir)
105100
end
106101

@@ -117,7 +112,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
117112
else
118113
error("Unknown assembly format $output")
119114
end
120-
asm, asm_meta = emit_asm(job, ir; strip, validate, format)
115+
asm, asm_meta = emit_asm(job, ir, format)
121116

122117
if output == :asm || output == :obj
123118
return asm, (; asm_meta..., ir_meta..., ir)
@@ -156,9 +151,14 @@ end
156151

157152
const __llvm_initialized = Ref(false)
158153

159-
@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
160-
libraries::Bool, optimize::Bool, cleanup::Bool,
161-
validate::Bool, only_entry::Bool)
154+
@locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...)
155+
# XXX: remove on next major version
156+
if !isempty(kwargs)
157+
Base.depwarn("The GPUCompiler `emit_llvm` function is an internal API. Use `GPUCompiler.compile` (with any kwargs passed to `CompilerConfig`) instead.", :emit_llvm)
158+
config = CompilerConfig(job.config; kwargs...)
159+
job = CompilerJob(job.source, config)
160+
end
161+
162162
if !__llvm_initialized[]
163163
InitializeAllTargets()
164164
InitializeAllTargetInfos()
@@ -183,7 +183,8 @@ const __llvm_initialized = Ref(false)
183183
entry = finish_module!(job, ir, entry)
184184

185185
# deferred code generation
186-
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
186+
has_deferred_jobs = job.config.toplevel && !job.config.only_entry &&
187+
haskey(functions(ir), "deferred_codegen")
187188
jobs = Dict{CompilerJob, String}(job => entry_fn)
188189
if has_deferred_jobs
189190
dyn_marker = functions(ir)["deferred_codegen"]
@@ -221,8 +222,8 @@ const __llvm_initialized = Ref(false)
221222
for dyn_job in keys(worklist)
222223
# cached compilation
223224
dyn_entry_fn = get!(jobs, dyn_job) do
224-
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
225-
parent_job=job)
225+
config = CompilerConfig(dyn_job.config; toplevel=false)
226+
dyn_ir, dyn_meta = codegen(:llvm, CompilerJob(dyn_job; config))
226227
dyn_entry_fn = LLVM.name(dyn_meta.entry)
227228
merge!(compiled, dyn_meta.compiled)
228229
@assert context(dyn_ir) == context(ir)
@@ -258,7 +259,7 @@ const __llvm_initialized = Ref(false)
258259
erase!(dyn_marker)
259260
end
260261

261-
if libraries
262+
if job.config.toplevel && job.config.libraries
262263
# load the runtime outside of a timing block (because it recurses into the compiler)
263264
if !uses_julia_runtime(job)
264265
runtime = load_runtime(job)
@@ -284,7 +285,7 @@ const __llvm_initialized = Ref(false)
284285
# mark everything internal except for entrypoints and any exported
285286
# global variables. this makes sure that the optimizer can, e.g.,
286287
# rewrite function signatures.
287-
if toplevel
288+
if job.config.toplevel
288289
preserved_gvs = collect(values(jobs))
289290
for gvar in globals(ir)
290291
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
@@ -310,7 +311,7 @@ const __llvm_initialized = Ref(false)
310311
# so that we can reconstruct the CompilerJob instead of setting it globally
311312
end
312313

313-
if optimize
314+
if job.config.toplevel && job.config.optimize
314315
@timeit_debug to "optimization" begin
315316
optimize!(job, ir; job.config.opt_level)
316317

@@ -337,7 +338,7 @@ const __llvm_initialized = Ref(false)
337338
entry = functions(ir)[entry_fn]
338339
end
339340

340-
if cleanup
341+
if job.config.toplevel && job.config.cleanup
341342
@timeit_debug to "clean-up" begin
342343
@dispose pb=NewPMPassBuilder() begin
343344
add!(pb, RecomputeGlobalsAAPass())
@@ -355,7 +356,7 @@ const __llvm_initialized = Ref(false)
355356
# we want to finish the module after optimization, so we cannot do so
356357
# during deferred code generation. instead, process the deferred jobs
357358
# here.
358-
if toplevel
359+
if job.config.toplevel
359360
entry = finish_ir!(job, ir, entry)
360361

361362
for (job′, fn′) in jobs
@@ -367,7 +368,7 @@ const __llvm_initialized = Ref(false)
367368
# replace non-entry function definitions with a declaration
368369
# NOTE: we can't do this before optimization, because the definitions of called
369370
# functions may affect optimization.
370-
if only_entry
371+
if job.config.only_entry
371372
for f in functions(ir)
372373
f == entry && continue
373374
isdeclaration(f) && continue
@@ -377,7 +378,7 @@ const __llvm_initialized = Ref(false)
377378
end
378379
end
379380

380-
if validate
381+
if job.config.toplevel && job.config.validate
381382
@timeit_debug to "Validation" begin
382383
check_ir(job, ir)
383384
end
@@ -390,10 +391,10 @@ const __llvm_initialized = Ref(false)
390391
return ir, (; entry, compiled)
391392
end
392393

393-
@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module;
394-
strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
394+
@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module,
395+
format::LLVM.API.LLVMCodeGenFileType)
395396
# NOTE: strip after validation to get better errors
396-
if strip
397+
if job.config.strip
397398
@timeit_debug to "Debug info removal" strip_debuginfo!(ir)
398399
end
399400

src/execution.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,20 @@ export split_kwargs, assign_args!
88
# split keyword argument expressions into groups. returns vectors of keyword argument
99
# values, one more than the number of groups (unmatched keywords in the last vector).
1010
# intended for use in macros; the resulting groups can be used in expressions.
11+
# can be used at run time, but not in performance-critical code.
1112
function split_kwargs(kwargs, kw_groups...)
1213
kwarg_groups = ntuple(_->[], length(kw_groups) + 1)
1314
for kwarg in kwargs
1415
# decode
15-
Meta.isexpr(kwarg, :(=)) || throw(ArgumentError("non-keyword argument like option '$kwarg'"))
16-
key, val = kwarg.args
16+
if Meta.isexpr(kwarg, :(=))
17+
# use in macros
18+
key, val = kwarg.args
19+
elseif kwarg isa Pair{Symbol,<:Any}
20+
# use in functions
21+
key, val = kwarg
22+
else
23+
throw(ArgumentError("non-keyword argument like option '$kwarg'"))
24+
end
1725
isa(key, Symbol) || throw(ArgumentError("non-symbolic keyword '$key'"))
1826

1927
# find a matching group
@@ -182,7 +190,7 @@ end
182190
end
183191

184192
struct DiskCacheEntry
185-
src::Type # Originally MethodInstance, but upon deserialize they were not uniqued...
193+
src::Type # Originally MethodInstance, but upon deserialize they were not uniqued...
186194
cfg::CompilerConfig
187195
asm
188196
end
@@ -262,7 +270,16 @@ end
262270
obj = linker(job, asm)
263271

264272
if ci === nothing
265-
ci = ci_cache_lookup(ci_cache(job), src, world, world)::CodeInstance
273+
ci = ci_cache_lookup(ci_cache(job), src, world, world)
274+
if ci === nothing
275+
error("""Did not find CodeInstance for $job.
276+
277+
Please make sure that the `compiler` function passed to `cached_compilation`
278+
invokes GPUCompiler with exactly the same configuration as passed to the API.
279+
280+
Note that you should do this by calling `GPUCompiler.compile`, and not by
281+
using reflection functions (which alter the compiler configuration).""")
282+
end
266283
key = (ci, cfg)
267284
end
268285
cache[key] = obj

0 commit comments

Comments
 (0)