Skip to content

Commit 860ec6a

Browse files
authored
Taking world ages seriously (#394)
1 parent 4e98899 commit 860ec6a

File tree

19 files changed

+516
-297
lines changed

19 files changed

+516
-297
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GPUCompiler"
22
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
33
authors = ["Tim Besard <tim.besard@gmail.com>"]
4-
version = "0.17.3"
4+
version = "0.18.0"
55

66
[deps]
77
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"

examples/kernel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntim
1616
kernel() = nothing
1717

1818
function main()
19-
source = FunctionSpec(typeof(kernel))
19+
source = FunctionSpec(typeof(kernel), Tuple{})
2020
target = NativeCompilerTarget()
2121
params = TestCompilerParams()
22-
job = CompilerJob(target, source, params)
22+
job = CompilerJob(source, target, params)
2323

2424
println(GPUCompiler.compile(:asm, job)[1])
2525
end

src/cache.jl

Lines changed: 114 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,82 @@
33
using Core.Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber, ReturnNode
44
using Base: _methods_by_ftype
55

6-
# generated function that crafts a custom code info to call the actual compiler.
7-
# this gives us the flexibility to insert manual back edges for automatic recompilation.
6+
# generated function that returns the world age of a compilation job. this can be used to
7+
# drive compilation, e.g. by using it as a key for a cache, as the age will change when a
8+
# function or any called function is redefined.
9+
10+
11+
"""
12+
get_world(ft, tt)
13+
14+
A special function that returns the world age in which the current definition of function
15+
type `ft`, invoked with argument types `tt`, is defined. This can be used to cache
16+
compilation results:
17+
18+
compilation_cache = Dict()
19+
function cache_compilation(ft, tt)
20+
world = get_world(ft, tt)
21+
get!(compilation_cache, (ft, tt, world)) do
22+
# compile
23+
end
24+
end
25+
26+
What makes this function special is that it is a generated function, returning a constant,
27+
whose result is automatically invalidated when the function `ft` (or any called function) is
28+
redefined. This makes this query ideally suited for hot code, where you want to avoid a
29+
costly look-up of the current world age on every invocation.
30+
31+
Normally, you shouldn't have to use this function, as it's used by `FunctionSpec`.
32+
33+
!!! warning
34+
35+
Due to a bug in Julia, JuliaLang/julia#34962, this function's results are only
36+
guaranteed to be correctly invalidated when the target function `ft` is executed or
37+
processed by codegen (e.g., by calling `code_llvm`).
38+
"""
39+
get_world
40+
41+
# generated functions currently do not know which world they are invoked for, so we fall
42+
# back to using the current world. this may be wrong when the generator is invoked in a
43+
# different world (TODO: when does this happen?)
844
#
9-
# we also increment a global specialization counter and pass it along to index the cache.
10-
11-
const specialization_counter = Ref{UInt}(0)
12-
@generated function specialization_id(job::CompilerJob{<:Any,<:Any,FunctionSpec{f,tt}}) where {f,tt}
13-
# get a hold of the method and code info of the kernel function
14-
sig = Tuple{f, tt.parameters...}
15-
# XXX: instead of typemax(UInt) we should use the world-age of the fspec
16-
mthds = _methods_by_ftype(sig, -1, typemax(UInt))
45+
# XXX: this should be fixed by JuliaLang/julia#48611
46+
47+
function get_world_generator(self, ::Type{Type{ft}}, ::Type{Type{tt}}) where {ft, tt}
48+
@nospecialize
49+
50+
# look up the method
51+
sig = Tuple{ft, tt.parameters...}
52+
min_world = Ref{UInt}(typemin(UInt))
53+
max_world = Ref{UInt}(typemax(UInt))
54+
has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results
55+
mthds = if VERSION >= v"1.7.0-DEV.1297"
56+
Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1,
57+
#=world=# typemax(UInt), #=ambig=# false,
58+
min_world, max_world, has_ambig)
59+
# XXX: use the correct method table to support overlaying kernels
60+
else
61+
Base._methods_by_ftype(sig, #=lim=# -1,
62+
#=world=# typemax(UInt), #=ambig=# false,
63+
min_world, max_world, has_ambig)
64+
end
65+
# XXX: using world=typemax(UInt) (i.e. -1) is wrong, but the current world isn't exposed to this generator
66+
67+
# check the validity of the method matches
68+
method_error = :(throw(MethodError(ft, tt)))
69+
mthds === nothing && return method_error
1770
Base.isdispatchtuple(tt) || return(:(error("$tt is not a dispatch tuple")))
18-
length(mthds) == 1 || return (:(throw(MethodError(job.source.f,job.source.tt))))
71+
length(mthds) == 1 || return method_error
72+
73+
# look up the method and code instance
1974
mtypes, msp, m = mthds[1]
2075
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
2176
ci = retrieve_code_info(mi)::CodeInfo
2277

23-
# generate a unique id to represent this specialization
24-
# TODO: just use the lower world age bound in which this code info is valid.
25-
# (the method instance doesn't change when called functions are changed).
26-
# but how to get that? the ci here always has min/max world 1/-1.
27-
# XXX: don't use `objectid(ci)` here, apparently it can alias (or the CI doesn't change?)
28-
id = (specialization_counter[] += 1)
78+
# XXX: we don't know the world age that this generator was requested to run in, so use
79+
# the current world (we cannot use the mi's world because that doesn't update when
80+
# called functions are changed). this isn't correct, but should be close.
81+
world = Base.get_world_counter()
2982

3083
# prepare a new code info
3184
new_ci = copy(ci)
@@ -34,22 +87,20 @@ const specialization_counter = Ref{UInt}(0)
3487
resize!(new_ci.linetable, 1) # see note below
3588
empty!(new_ci.ssaflags)
3689
new_ci.ssavaluetypes = 0
90+
new_ci.min_world = min_world[]
91+
new_ci.max_world = max_world[]
3792
new_ci.edges = MethodInstance[mi]
3893
# XXX: setting this edge does not give us proper method invalidation, see
3994
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
4095
# invoking `code_llvm` also does the necessary codegen, as does calling the
4196
# underlying C methods -- which GPUCompiler does, so everything Just Works.
4297

4398
# prepare the slots
44-
new_ci.slotnames = Symbol[Symbol("#self#"), :cache, :job, :compiler, :linker]
45-
new_ci.slotflags = UInt8[0x00 for i = 1:5]
46-
cache = SlotNumber(2)
47-
job = SlotNumber(3)
48-
compiler = SlotNumber(4)
49-
linker = SlotNumber(5)
50-
51-
# call the compiler
52-
push!(new_ci.code, ReturnNode(id))
99+
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
100+
new_ci.slotflags = UInt8[0x00 for i = 1:3]
101+
102+
# return the world
103+
push!(new_ci.code, ReturnNode(world))
53104
push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
54105
push!(new_ci.codelocs, 1) # see note below
55106
new_ci.ssavaluetypes += 1
@@ -62,17 +113,48 @@ const specialization_counter = Ref{UInt}(0)
62113
return new_ci
63114
end
64115

116+
@eval function get_world(ft, tt)
117+
$(Expr(:meta, :generated_only))
118+
$(Expr(:meta,
119+
:generated,
120+
Expr(:new,
121+
Core.GeneratedFunctionStub,
122+
:get_world_generator,
123+
Any[:get_world, :ft, :tt],
124+
Any[],
125+
@__LINE__,
126+
QuoteNode(Symbol(@__FILE__)),
127+
true)))
128+
end
129+
65130
const cache_lock = ReentrantLock()
131+
132+
"""
133+
cached_compilation(cache::AbstractDict, job::CompilerJob, compiler, linker)
134+
135+
Compile `job` using `compiler` and `linker`, and store the result in `cache`.
136+
137+
The `cache` argument should be a dictionary that can be indexed using a `UInt` and store
138+
whatever the `linker` function returns. The `compiler` function should take a `CompilerJob`
139+
and return data that can be cached across sessions (e.g., LLVM IR). This data is then
140+
forwarded, along with the `CompilerJob`, to the `linker` function which is allowed to create
141+
session-dependent objects (e.g., a `CuModule`).
142+
"""
66143
function cached_compilation(cache::AbstractDict,
67144
@nospecialize(job::CompilerJob),
68145
compiler::Function, linker::Function)
69-
# XXX: CompilerJob contains a world age, so can't be respecialized.
70-
# have specialization_id take a f/tt and return a world to construct a CompilerJob?
71-
key = hash(job, specialization_id(job))
72-
force_compilation = compile_hook[] !== nothing
146+
# NOTE: it is OK to index the compilation cache directly with the compilation job, i.e.,
147+
# using a world age instead of intersecting world age ranges, because we expect
148+
# that the world age is acquired through calling `get_world` and thus will only
149+
# ever change when the kernel function is redefined.
150+
#
151+
# if we ever want to be able to index the cache using a compilation job that
152+
# contains a more recent world age, yet still return an older cached object that
153+
# would still be valid, we'd need the cache to store world ranges instead and
154+
# use an invalidation callback to add upper bounds to entries.
155+
key = hash(job)
73156

74-
# XXX: by taking the hash, we index the compilation cache directly with the world age.
75-
# that's wrong; we should perform an intersection with the entry its bounds.
157+
force_compilation = compile_hook[] !== nothing
76158

77159
# NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
78160
lock(cache_lock)

src/driver.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,13 @@ end
158158

159159
# get the method instance
160160
sig = typed_signature(job)
161-
meth = which(sig)
161+
meth = if VERSION >= v"1.10.0-DEV.65"
162+
Base._which(sig; world=job.source.world).method
163+
elseif VERSION >= v"1.7.0-DEV.435"
164+
Base._which(sig, job.source.world).method
165+
else
166+
ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), sig, job.source.world)
167+
end
162168

163169
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
164170
(Any, Any), sig, meth.sig)::Core.SimpleVector
@@ -175,6 +181,10 @@ end
175181
end
176182
end
177183

184+
# ensure that the returned method instance is valid in the compilation world.
185+
# otherwise, `jl_create_native` won't actually emit any code.
186+
@assert method_instance.def.primary_world <= job.source.world <= method_instance.def.deleted_world
187+
178188
return method_instance, ()
179189
end
180190

@@ -189,9 +199,9 @@ Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
189199
ptr
190200
end
191201

192-
@generated function deferred_codegen(::Val{f}, ::Val{tt}) where {f,tt}
202+
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
193203
id = length(deferred_codegen_jobs) + 1
194-
deferred_codegen_jobs[id] = FunctionSpec(f,tt)
204+
deferred_codegen_jobs[id] = FunctionSpec(ft, tt)
195205

196206
pseudo_ptr = reinterpret(Ptr{Cvoid}, id)
197207
quote
@@ -286,10 +296,19 @@ const __llvm_initialized = Ref(false)
286296
id = convert(Int, first(operands(call)))
287297

288298
global deferred_codegen_jobs
289-
dyn_job = deferred_codegen_jobs[id]
290-
if dyn_job isa FunctionSpec
291-
dyn_job = similar(job, dyn_job)
299+
dyn_val = deferred_codegen_jobs[id]
300+
301+
# get a job in the appropriate world
302+
dyn_job = if dyn_val isa CompilerJob
303+
dyn_spec = FunctionSpec(dyn_val.source; world=job.source.world)
304+
CompilerJob(dyn_val; source=dyn_spec)
305+
elseif dyn_val isa FunctionSpec
306+
dyn_spec = FunctionSpec(dyn_val; world=job.source.world)
307+
CompilerJob(job; source=dyn_spec)
308+
else
309+
error("invalid deferred job type $(typeof(dyn_val))")
292310
end
311+
293312
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
294313
end
295314

0 commit comments

Comments
 (0)