Skip to content

Commit 14fadbc

Browse files
authored
Merge pull request #417 from JuliaGPU/tb/functionspec
Replace FunctionSpec with methodinstance
2 parents 94085d0 + 1a2f444 commit 14fadbc

File tree

19 files changed

+478
-493
lines changed

19 files changed

+478
-493
lines changed

examples/kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntim
1616
kernel() = nothing
1717

1818
function main()
19-
source = FunctionSpec(typeof(kernel), Tuple{})
19+
source = methodinstance(typeof(kernel), Tuple{})
2020
target = NativeCompilerTarget()
2121
params = TestCompilerParams()
2222
config = CompilerConfig(target, params)

src/cache.jl

Lines changed: 11 additions & 218 deletions
Original file line numberDiff line numberDiff line change
@@ -1,218 +1,13 @@
1-
# compilation cache
2-
3-
using Core.Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber, ReturnNode
4-
using Base: _methods_by_ftype
5-
6-
# generated function that returns the world age of a compilation job. this can be used to
7-
# drive compilation, e.g. by using it as a key for a cache, as the age will change when a
8-
# function or any called function is redefined.
9-
10-
11-
"""
12-
get_world(ft, tt)
13-
14-
A special function that returns the world age in which the current definition of function
15-
type `ft`, invoked with argument types `tt`, is defined. This can be used to cache
16-
compilation results:
17-
18-
compilation_cache = Dict()
19-
function cache_compilation(ft, tt)
20-
world = get_world(ft, tt)
21-
get!(compilation_cache, (ft, tt, world)) do
22-
# compile
23-
end
24-
end
25-
26-
What makes this function special is that it is a generated function, returning a constant,
27-
whose result is automatically invalidated when the function `ft` (or any called function) is
28-
redefined. This makes this query ideally suited for hot code, where you want to avoid a
29-
costly look-up of the current world age on every invocation.
30-
31-
Normally, you shouldn't have to use this function, as it's used by `FunctionSpec`.
32-
33-
!!! warning
34-
35-
Due to a bug in Julia, JuliaLang/julia#34962, this function's results are only
36-
guaranteed to be correctly invalidated when the target function `ft` is executed or
37-
processed by codegen (e.g., by calling `code_llvm`).
38-
"""
39-
get_world
40-
41-
if VERSION >= v"1.10.0-DEV.873"
42-
43-
# on 1.10 (JuliaLang/julia#48611) the generated function knows which world it was invoked in
44-
45-
function _generated_ex(world, source, ex)
46-
stub = Core.GeneratedFunctionStub(identity, Core.svec(:get_world, :ft, :tt), Core.svec())
47-
stub(world, source, ex)
48-
end
49-
50-
function get_world_generator(world::UInt, source, self, ft::Type, tt::Type)
51-
@nospecialize
52-
@assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt)
53-
ft = ft.parameters[1]
54-
tt = tt.parameters[1]
55-
56-
# look up the method
57-
method_error = :(throw(MethodError(ft, tt, $world)))
58-
Base.isdispatchtuple(tt) || return _generated_ex(world, source, :(error("$tt is not a dispatch tuple")))
59-
sig = Tuple{ft, tt.parameters...}
60-
min_world = Ref{UInt}(typemin(UInt))
61-
max_world = Ref{UInt}(typemax(UInt))
62-
has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results
63-
mthds = if VERSION >= v"1.7.0-DEV.1297"
64-
Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1,
65-
world, #=ambig=# false,
66-
min_world, max_world, has_ambig)
67-
# XXX: use the correct method table to support overlaying kernels
68-
else
69-
Base._methods_by_ftype(sig, #=lim=# -1,
70-
world, #=ambig=# false,
71-
min_world, max_world, has_ambig)
72-
end
73-
mthds === nothing && return _generated_ex(world, source, method_error)
74-
length(mthds) == 1 || return _generated_ex(world, source, method_error)
75-
76-
# look up the method and code instance
77-
mtypes, msp, m = mthds[1]
78-
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
79-
ci = retrieve_code_info(mi, world)::CodeInfo
80-
81-
# prepare a new code info
82-
new_ci = copy(ci)
83-
empty!(new_ci.code)
84-
empty!(new_ci.codelocs)
85-
resize!(new_ci.linetable, 1) # see note below
86-
empty!(new_ci.ssaflags)
87-
new_ci.ssavaluetypes = 0
88-
new_ci.min_world = min_world[]
89-
new_ci.max_world = max_world[]
90-
new_ci.edges = MethodInstance[mi]
91-
# XXX: setting this edge does not give us proper method invalidation, see
92-
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
93-
# invoking `code_llvm` also does the necessary codegen, as does calling the
94-
# underlying C methods -- which GPUCompiler does, so everything Just Works.
95-
96-
# prepare the slots
97-
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
98-
new_ci.slotflags = UInt8[0x00 for i = 1:3]
99-
100-
# return the world
101-
push!(new_ci.code, ReturnNode(world))
102-
push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
103-
push!(new_ci.codelocs, 1) # see note below
104-
new_ci.ssavaluetypes += 1
105-
106-
# NOTE: we keep the first entry of the original linetable, and use it for location info
107-
# on the call to check_cache. we can't not have a codeloc (using 0 causes
108-
# corruption of the back trace), and reusing the target function's info
109-
# has as advantage that we see the name of the kernel in the backtraces.
110-
111-
return new_ci
112-
end
113-
114-
@eval function get_world(ft, tt)
115-
$(Expr(:meta, :generated_only))
116-
$(Expr(:meta, :generated, get_world_generator))
117-
end
118-
119-
else
120-
121-
# on older versions of Julia we fall back to looking up the current world. this may be wrong
122-
# when the generator is invoked in a different world (TODO: when does this happen?)
123-
124-
function get_world_generator(self, ft::Type, tt::Type)
125-
@nospecialize
126-
@assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt)
127-
ft = ft.parameters[1]
128-
tt = tt.parameters[1]
129-
130-
# look up the method
131-
method_error = :(throw(MethodError(ft, tt)))
132-
Base.isdispatchtuple(tt) || return(:(error("$tt is not a dispatch tuple")))
133-
sig = Tuple{ft, tt.parameters...}
134-
min_world = Ref{UInt}(typemin(UInt))
135-
max_world = Ref{UInt}(typemax(UInt))
136-
has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results
137-
mthds = if VERSION >= v"1.7.0-DEV.1297"
138-
Base._methods_by_ftype(sig, #=mt=# nothing, #=lim=# -1,
139-
#=world=# typemax(UInt), #=ambig=# false,
140-
min_world, max_world, has_ambig)
141-
# XXX: use the correct method table to support overlaying kernels
142-
else
143-
Base._methods_by_ftype(sig, #=lim=# -1,
144-
#=world=# typemax(UInt), #=ambig=# false,
145-
min_world, max_world, has_ambig)
146-
end
147-
# XXX: using world=-1 is wrong, but the current world isn't exposed to this generator
148-
mthds === nothing && return method_error
149-
length(mthds) == 1 || return method_error
150-
151-
# look up the method and code instance
152-
mtypes, msp, m = mthds[1]
153-
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
154-
ci = retrieve_code_info(mi)::CodeInfo
155-
156-
# XXX: we don't know the world age that this generator was requested to run in, so use
157-
# the current world (we cannot use the mi's world because that doesn't update when
158-
# called functions are changed). this isn't correct, but should be close.
159-
world = Base.get_world_counter()
160-
161-
# prepare a new code info
162-
new_ci = copy(ci)
163-
empty!(new_ci.code)
164-
empty!(new_ci.codelocs)
165-
resize!(new_ci.linetable, 1) # see note below
166-
empty!(new_ci.ssaflags)
167-
new_ci.ssavaluetypes = 0
168-
new_ci.min_world = min_world[]
169-
new_ci.max_world = max_world[]
170-
new_ci.edges = MethodInstance[mi]
171-
# XXX: setting this edge does not give us proper method invalidation, see
172-
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
173-
# invoking `code_llvm` also does the necessary codegen, as does calling the
174-
# underlying C methods -- which GPUCompiler does, so everything Just Works.
175-
176-
# prepare the slots
177-
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
178-
new_ci.slotflags = UInt8[0x00 for i = 1:3]
179-
180-
# return the world
181-
push!(new_ci.code, ReturnNode(world))
182-
push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
183-
push!(new_ci.codelocs, 1) # see note below
184-
new_ci.ssavaluetypes += 1
185-
186-
# NOTE: we keep the first entry of the original linetable, and use it for location info
187-
# on the call to check_cache. we can't not have a codeloc (using 0 causes
188-
# corruption of the back trace), and reusing the target function's info
189-
# has as advantage that we see the name of the kernel in the backtraces.
190-
191-
return new_ci
192-
end
193-
194-
@eval function get_world(ft, tt)
195-
$(Expr(:meta, :generated_only))
196-
$(Expr(:meta,
197-
:generated,
198-
Expr(:new,
199-
Core.GeneratedFunctionStub,
200-
:get_world_generator,
201-
Any[:get_world, :ft, :tt],
202-
Any[],
203-
@__LINE__,
204-
QuoteNode(Symbol(@__FILE__)),
205-
true)))
206-
end
207-
208-
end
1+
# cached compilation
2092

2103
const cache_lock = ReentrantLock()
2114

2125
"""
213-
cached_compilation(cache::Dict{UInt}, job::CompilerJob, compiler, linker)
6+
cached_compilation(cache::Dict{UInt}, cfg::CompilerConfig, ft::Type, tt::Type,
7+
compiler, linker)
2148
215-
Compile `job` using `compiler` and `linker`, and store the result in `cache`.
9+
Compile a method instance, identified by its function type `ft` and argument types `tt`,
10+
using `compiler` and `linker`, and store the result in `cache`.
21611
21712
The `cache` argument should be a dictionary that can be indexed using a `UInt` and store
21813
whatever the `linker` function returns. The `compiler` function should take a `CompilerJob`
@@ -224,10 +19,9 @@ function cached_compilation(cache::AbstractDict{UInt,V},
22419
cfg::CompilerConfig,
22520
ft::Type, tt::Type,
22621
compiler::Function, linker::Function) where {V}
227-
# NOTE: it is OK to index the compilation cache directly with the world age, instead of
228-
# intersecting world age ranges, because we the world age is aquired by calling
229-
# `get_world` and thus will only change when the kernel function is redefined.
230-
world = get_world(ft, tt)
22+
# NOTE: we only use the codegen world age for invalidation purposes;
23+
# actual compilation happens at the current world age.
24+
world = codegen_world_age(ft, tt)
23125
key = hash(ft)
23226
key = hash(tt, key)
23327
key = hash(world, key)
@@ -240,16 +34,15 @@ function cached_compilation(cache::AbstractDict{UInt,V},
24034

24135
LLVM.Interop.assume(isassigned(compile_hook))
24236
if obj === nothing || compile_hook[] !== nothing
243-
obj = actual_compilation(cache, key, cfg, ft, tt, world, compiler, linker)::V
37+
obj = actual_compilation(cache, key, cfg, ft, tt, compiler, linker)::V
24438
end
24539
return obj::V
24640
end
24741

24842
@noinline function actual_compilation(cache::AbstractDict, key::UInt,
249-
cfg::CompilerConfig,
250-
ft::Type, tt::Type, world,
43+
cfg::CompilerConfig, ft::Type, tt::Type,
25144
compiler::Function, linker::Function)
252-
src = FunctionSpec(ft, tt, world)
45+
src = methodinstance(ft, tt)
25346
job = CompilerJob(src, cfg)
25447

25548
asm = nothing

0 commit comments

Comments
 (0)