Skip to content

Commit c51dab3

Browse files
authored
Split CompilerJob in dynamic and static part. (#395)
1 parent 860ec6a commit c51dab3

24 files changed

+206
-188
lines changed

examples/kernel.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ function main()
1919
source = FunctionSpec(typeof(kernel), Tuple{})
2020
target = NativeCompilerTarget()
2121
params = TestCompilerParams()
22-
job = CompilerJob(source, target, params)
22+
config = CompilerConfig(target, params)
23+
job = CompilerJob(config, source)
2324

2425
println(GPUCompiler.compile(:asm, job)[1])
2526
end

src/bpf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ const bpf_intrinsics = () # TODO
3232
isintrinsic(::CompilerJob{BPFCompilerTarget}, fn::String) = in(fn, bpf_intrinsics)
3333

3434
valid_function_pointer(job::CompilerJob{BPFCompilerTarget}, ptr::Ptr{Cvoid}) =
35-
reinterpret(UInt, ptr) in job.target.function_pointers
35+
reinterpret(UInt, ptr) in job.config.target.function_pointers

src/cache.jl

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ end
130130
const cache_lock = ReentrantLock()
131131

132132
"""
133-
cached_compilation(cache::Dict, job::CompilerJob, compiler, linker)
133+
cached_compilation(cache::Dict{UInt}, cfg::CompilerConfig, ft::Type, tt::Type, compiler, linker)
134134
135135
Compile the function identified by `ft` and `tt` under the configuration `cfg`, using `compiler` and `linker`, and store the result in `cache`.
136136
@@ -140,46 +140,57 @@ and return data that can be cached across sessions (e.g., LLVM IR). This data is
140140
forwarded, along with the `CompilerJob`, to the `linker` function which is allowed to create
141141
session-dependent objects (e.g., a `CuModule`).
142142
"""
143-
function cached_compilation(cache::AbstractDict,
144-
@nospecialize(job::CompilerJob),
145-
compiler::Function, linker::Function)
146-
# NOTE: it is OK to index the compilation cache directly with the compilation job, i.e.,
147-
# using a world age instead of intersecting world age ranges, because we expect
148-
# that the world age is aquired through calling `get_world` and thus will only
149-
# ever change when the kernel function is redefined.
150-
#
151-
# if we ever want to be able to index the cache using a compilation job that
152-
# contains a more recent world age, yet still return an older cached object that
153-
# would still be valid, we'd need the cache to store world ranges instead and
154-
# use an invalidation callback to add upper bounds to entries.
155-
key = hash(job)
156-
157-
force_compilation = compile_hook[] !== nothing
158-
159-
# NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
143+
function cached_compilation(cache::AbstractDict{UInt,V},
144+
cfg::CompilerConfig,
145+
ft::Type, tt::Type,
146+
compiler::Function, linker::Function) where {V}
147+
# NOTE: it is OK to index the compilation cache directly with the world age, instead of
148+
# intersecting world age ranges, because the world age is acquired by calling
149+
# `get_world` and thus will only change when the kernel function is redefined.
150+
world = get_world(ft, tt)
151+
key = hash(ft)
152+
key = hash(tt, key)
153+
key = hash(world, key)
154+
key = hash(cfg, key)
155+
156+
# NOTE: no use of lock(::Function)/@lock/get! to avoid try/catch and closure overhead
160157
lock(cache_lock)
161-
try
162-
obj = get(cache, key, nothing)
163-
if obj === nothing || force_compilation
164-
asm = nothing
158+
obj = get(cache, key, nothing)
159+
unlock(cache_lock)
165160

166-
# compile
167-
if asm === nothing
168-
if compile_hook[] !== nothing
169-
compile_hook[](job)
170-
end
171-
172-
asm = compiler(job)
173-
end
174-
175-
# link (but not if we got here because of forced compilation)
176-
if obj === nothing
177-
obj = linker(job, asm)
178-
cache[key] = obj
179-
end
161+
LLVM.Interop.assume(isassigned(compile_hook))
162+
if obj === nothing || compile_hook[] !== nothing
163+
obj = actual_compilation(cache, key, cfg, ft, tt, world, compiler, linker)::V
164+
end
165+
return obj::V
166+
end
167+
168+
@noinline function actual_compilation(cache::AbstractDict, key::UInt,
169+
cfg::CompilerConfig,
170+
ft::Type, tt::Type, world,
171+
compiler::Function, linker::Function)
172+
src = FunctionSpec(ft, tt, world)
173+
job = CompilerJob(cfg, src)
174+
175+
asm = nothing
176+
# TODO: consider loading the assembly from an on-disk cache here
177+
178+
# compile
179+
if asm === nothing
180+
if compile_hook[] !== nothing
181+
compile_hook[](job)
180182
end
183+
184+
asm = compiler(job)
185+
end
186+
187+
# link (but not if we got here because of forced compilation,
188+
# in which case the cache will already be populated)
189+
lock(cache_lock) do
190+
haskey(cache, key) && return cache[key]
191+
192+
obj = linker(job, asm)
193+
cache[key] = obj
181194
obj
182-
finally
183-
unlock(cache_lock)
184195
end
185196
end

src/driver.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ const __llvm_initialized = Ref(false)
227227

228228
@timeit_debug to "IR generation" begin
229229
ir, compiled = irgen(job, method_instance; ctx)
230-
if job.entry_abi === :specfunc
230+
if job.config.entry_abi === :specfunc
231231
entry_fn = compiled[method_instance].specfunc
232232
else
233233
entry_fn = compiled[method_instance].func
@@ -300,11 +300,11 @@ const __llvm_initialized = Ref(false)
300300

301301
# get a job in the appropriate world
302302
dyn_job = if dyn_val isa CompilerJob
303-
dyn_spec = FunctionSpec(dyn_val.source; world=job.source.world)
304-
CompilerJob(dyn_val; source=dyn_spec)
303+
dyn_src = FunctionSpec(dyn_val.source; world=job.source.world)
304+
CompilerJob(dyn_val.config, dyn_src)
305305
elseif dyn_val isa FunctionSpec
306-
dyn_spec = FunctionSpec(dyn_val; world=job.source.world)
307-
CompilerJob(job; source=dyn_spec)
306+
dyn_src = FunctionSpec(dyn_val; world=job.source.world)
307+
CompilerJob(job.config, dyn_src)
308308
else
309309
error("invalid deferred job type $(typeof(dyn_val))")
310310
end
@@ -349,7 +349,7 @@ const __llvm_initialized = Ref(false)
349349

350350
@timeit_debug to "IR post-processing" begin
351351
# mark the kernel entry-point functions (optimization may need it)
352-
if job.source.kernel
352+
if job.config.kernel
353353
push!(metadata(ir)["julia.kernel"], MDNode([entry]; ctx=unwrap_context(ctx)))
354354

355355
# IDEA: save all jobs, not only kernels, and save other attributes

src/gcn.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ end
3030

3131
# TODO: encode debug build or not in the compiler job
3232
# https://github.com/JuliaGPU/CUDAnative.jl/issues/368
33-
runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.target.dev_isa)$(job.target.features)"
33+
runtime_slug(job::CompilerJob{GCNCompilerTarget}) = "gcn-$(job.config.target.dev_isa)$(job.config.target.features)"
3434

3535
const gcn_intrinsics = () # TODO: ("vprintf", "__assertfail", "malloc", "free")
3636
isintrinsic(::CompilerJob{GCNCompilerTarget}, fn::String) = in(fn, gcn_intrinsics)
3737

3838
function process_entry!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function)
3939
entry = invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
4040

41-
if job.source.kernel
41+
if job.config.kernel
4242
# calling convention
4343
callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)
4444
end
@@ -54,7 +54,7 @@ function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
5454
mod::LLVM.Module, entry::LLVM.Function)
5555
entry = invoke(finish_module!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
5656

57-
if job.source.kernel
57+
if job.config.kernel
5858
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
5959
entry = lower_byval(job, mod, entry)
6060
end
@@ -84,10 +84,10 @@ end
8484
function optimize_module!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module)
8585
@static if VERSION < v"1.9.0-DEV.1018"
8686
# revert back to the AMDGPU target
87-
triple!(mod, llvm_triple(job.target))
88-
datalayout!(mod, julia_datalayout(job.target))
87+
triple!(mod, llvm_triple(job.config.target))
88+
datalayout!(mod, julia_datalayout(job.config.target))
8989

90-
tm = llvm_machine(job.target)
90+
tm = llvm_machine(job.config.target)
9191
@dispose pm=ModulePassManager() begin
9292
add_library_info!(pm, triple(mod))
9393
add_transform_info!(pm, tm)

src/interface.jl

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -66,34 +66,24 @@ struct FunctionSpec
6666
tt::Type
6767
world::UInt
6868

69-
kernel::Bool
70-
name::Union{Nothing,String}
71-
72-
FunctionSpec(ft::Type, tt::Type, world::Integer=get_world(ft, tt);
73-
kernel=true, name=nothing) =
74-
new(ft, tt, world, kernel, name)
69+
FunctionSpec(ft::Type, tt::Type, world::Integer=get_world(ft, tt)) =
70+
new(ft, tt, world)
7571
end
7672

7773
# copy constructor
78-
FunctionSpec(spec::FunctionSpec; ft=spec.ft, tt=spec.tt, world=spec.world,
79-
kernel=spec.kernel, name=spec.name) =
80-
FunctionSpec(ft, tt, world; kernel, name)
74+
FunctionSpec(spec::FunctionSpec; ft=spec.ft, tt=spec.tt, world=spec.world) =
75+
FunctionSpec(ft, tt, world)
8176

8277
function Base.hash(spec::FunctionSpec, h::UInt)
8378
h = hash(spec.ft, h)
8479
h = hash(spec.tt, h)
8580
h = hash(spec.world, h)
8681

87-
h = hash(spec.kernel, h)
88-
h = hash(spec.name, h)
89-
9082
return h
9183
end
9284

9385
function signature(@nospecialize(spec::FunctionSpec))
94-
fn = if spec.name !== nothing
95-
spec.name
96-
elseif spec.ft.name.mt == Symbol.name.mt
86+
fn = if spec.ft.name.mt == Symbol.name.mt
9787
# uses shared method table, so name is not unique to this function type
9888
nameof(spec.ft)
9989
else
@@ -104,22 +94,21 @@ function signature(@nospecialize(spec::FunctionSpec))
10494
end
10595

10696
function Base.show(io::IO, @nospecialize(spec::FunctionSpec))
107-
spec.kernel ? print(io, "kernel ") : print(io, "function ")
10897
print(io, signature(spec), " in world ", spec.world)
10998
end
11099

111100

112-
## job
101+
## config
113102

114-
export CompilerJob
103+
export CompilerConfig
115104

116-
# a specific invocation of the compiler, bundling everything needed to generate code
105+
# the configuration of the compiler
117106

118107
"""
119-
CompilerJob(source, target, params; entry_abi=:specfunc, always_inline=false)
108+
CompilerConfig(target, params; kernel=true, entry_abi=:specfunc, always_inline=false)
120109
121-
Construct a `CompilerJob` for `source` that will be used to drive compilation for the given
122-
`target` and `params`.
110+
Construct a `CompilerConfig` that will be used to drive compilation for the given `target`
111+
and `params`.
123112
124113
The `entry_abi` can be either `:specfunc` (the default) or `:func`. `:specfunc` expects the
125114
arguments to be passed in registers, simple return values are returned in registers as well,
@@ -133,46 +122,62 @@ generally easier to invoke directly.
133122
`always_inline` specifies if the Julia front-end should inline all functions into one if
134123
possible.
135124
"""
136-
struct CompilerJob{T,P}
125+
struct CompilerConfig{T,P}
137126
target::T
138127
params::P
139-
source::FunctionSpec
140128

129+
kernel::Bool
141130
entry_abi::Symbol
142131
always_inline::Bool
143132

144-
function CompilerJob(source::FunctionSpec,
145-
target::AbstractCompilerTarget,
146-
params::AbstractCompilerParams;
147-
entry_abi::Symbol=:specfunc, always_inline=false)
133+
function CompilerConfig(target::AbstractCompilerTarget,
134+
params::AbstractCompilerParams;
135+
kernel::Bool=true,
136+
entry_abi::Symbol=:specfunc,
137+
always_inline=false)
148138
if entry_abi ∉ (:specfunc, :func)
149139
error("Unknown entry_abi=$entry_abi")
150140
end
151-
new{typeof(target), typeof(params)}(target, params, source, entry_abi, always_inline)
141+
new{typeof(target), typeof(params)}(target, params, kernel, entry_abi, always_inline)
152142
end
153143
end
154144

155145
# copy constructor
156-
CompilerJob(job::CompilerJob; source=job.source, target=job.target, params=job.params,
157-
entry_abi=job.entry_abi, always_inline=job.always_inline) =
158-
CompilerJob(source, target, params; entry_abi, always_inline)
146+
CompilerConfig(cfg::CompilerConfig; target=cfg.target, params=cfg.params,
147+
kernel=cfg.kernel, entry_abi=cfg.entry_abi, always_inline=cfg.always_inline) =
148+
CompilerConfig(target, params; kernel, entry_abi, always_inline)
159149

160-
function Base.show(io::IO, @nospecialize(job::CompilerJob{T})) where {T}
161-
print(io, "CompilerJob of ", job.source, " for ", T)
150+
function Base.show(io::IO, @nospecialize(cfg::CompilerConfig{T})) where {T}
151+
print(io, "CompilerConfig for ", T)
162152
end
163153

164-
function Base.hash(job::CompilerJob, h::UInt)
165-
h = hash(job.source, h)
166-
h = hash(job.target, h)
167-
h = hash(job.params, h)
154+
function Base.hash(cfg::CompilerConfig, h::UInt)
155+
h = hash(cfg.target, h)
156+
h = hash(cfg.params, h)
168157

169-
h = hash(job.entry_abi, h)
170-
h = hash(job.always_inline, h)
158+
h = hash(cfg.kernel, h)
159+
h = hash(cfg.entry_abi, h)
160+
h = hash(cfg.always_inline, h)
171161

172162
return h
173163
end
174164

175165

166+
## job
167+
168+
export CompilerJob
169+
170+
# a specific invocation of the compiler, bundling everything needed to generate code
171+
172+
struct CompilerJob{T,P}
173+
config::CompilerConfig{T,P}
174+
source::FunctionSpec
175+
176+
CompilerJob(cfg::CompilerConfig{T,P}, src::FunctionSpec) where {T,P} =
177+
new{T,P}(cfg, src)
178+
end
179+
180+
176181
## contexts
177182

178183
if VERSION >= v"1.9.0-DEV.516"
@@ -233,7 +238,7 @@ function process_entry!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
233238
entry::LLVM.Function)
234239
ctx = context(mod)
235240

236-
if job.source.kernel && needs_byval(job)
241+
if job.config.kernel && needs_byval(job)
237242
# pass all bitstypes by value; by default Julia passes aggregates by reference
238243
# (this improves performance, and is mandated by certain back-ends like SPIR-V).
239244
args = classify_arguments(job, eltype(llvmtype(entry)))
@@ -279,7 +284,7 @@ valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false
279284
# the codeinfo cache to use
280285
function ci_cache(@nospecialize(job::CompilerJob))
281286
lock(GLOBAL_CI_CACHES_LOCK) do
282-
cache = get!(GLOBAL_CI_CACHES, (typeof(job.target), inference_params(job), optimization_params(job))) do
287+
cache = get!(GLOBAL_CI_CACHES, (typeof(job.config.target), inference_params(job), optimization_params(job))) do
283288
CodeCache()
284289
end
285290
return cache
@@ -302,7 +307,7 @@ function optimization_params(@nospecialize(job::CompilerJob))
302307
kwargs = (kwargs..., unoptimize_throw_blocks=false)
303308
end
304309

305-
if job.always_inline
310+
if job.config.always_inline
306311
kwargs = (kwargs..., inline_cost_threshold=typemax(Int))
307312
end
308313

0 commit comments

Comments
 (0)