3
3
using Core. Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber, ReturnNode
4
4
using Base: _methods_by_ftype
5
5
6
- # generated function that crafts a custom code info to call the actual compiler.
7
- # this gives us the flexibility to insert manual back edges for automatic recompilation.
6
+ # generated function that returns the world age of a compilation job. this can be used to
7
+ # drive compilation, e.g. by using it as a key for a cache, as the age will change when a
8
+ # function or any called function is redefined.
9
+
10
+
11
+ """
12
+ get_world(ft, tt)
13
+
14
+ A special function that returns the world age in which the current definition of function
15
+ type `ft`, invoked with argument types `tt`, is defined. This can be used to cache
16
+ compilation results:
17
+
18
+ compilation_cache = Dict()
19
+ function cache_compilation(ft, tt)
20
+ world = get_world(ft, tt)
21
+ get!(compilation_cache, (ft, tt, world)) do
22
+ # compile
23
+ end
24
+ end
25
+
26
+ What makes this function special is that it is a generated function, returning a constant,
27
+ whose result is automatically invalidated when the function `ft` (or any called function) is
28
+ redefined. This makes this query ideally suited for hot code, where you want to avoid a
29
+ costly look-up of the current world age on every invocation.
30
+
31
+ Normally, you shouldn't have to use this function, as it's used by `FunctionSpec`.
32
+
33
+ !!! warning
34
+
35
+ Due to a bug in Julia, JuliaLang/julia#34962, this function's results are only
36
+ guaranteed to be correctly invalidated when the target function `ft` is executed or
37
+ processed by codegen (e.g., by calling `code_llvm`).
38
+ """
39
+ get_world
40
+
41
+ # generated functions currently do not know which world they are invoked for, so we fall
42
+ # back to using the current world. this may be wrong when the generator is invoked in a
43
+ # different world (TODO : when does this happen?)
8
44
#
9
- # we also increment a global specialization counter and pass it along to index the cache.
10
-
11
- const specialization_counter = Ref {UInt} (0 )
12
- @generated function specialization_id (job:: CompilerJob{<:Any,<:Any,FunctionSpec{f,tt}} ) where {f,tt}
13
- # get a hold of the method and code info of the kernel function
14
- sig = Tuple{f, tt. parameters... }
15
- # XXX : instead of typemax(UInt) we should use the world-age of the fspec
16
- mthds = _methods_by_ftype (sig, - 1 , typemax (UInt))
45
+ # XXX : this should be fixed by JuliaLang/julia#48611
46
+
47
+ function get_world_generator (self, :: Type{Type{ft}} , :: Type{Type{tt}} ) where {ft, tt}
48
+ @nospecialize
49
+
50
+ # look up the method
51
+ sig = Tuple{ft, tt. parameters... }
52
+ min_world = Ref {UInt} (typemin (UInt))
53
+ max_world = Ref {UInt} (typemax (UInt))
54
+ has_ambig = Ptr {Int32} (C_NULL ) # don't care about ambiguous results
55
+ mthds = if VERSION >= v " 1.7.0-DEV.1297"
56
+ Base. _methods_by_ftype (sig, #= mt=# nothing , #= lim=# - 1 ,
57
+ #= world=# typemax (UInt), #= ambig=# false ,
58
+ min_world, max_world, has_ambig)
59
+ # XXX : use the correct method table to support overlaying kernels
60
+ else
61
+ Base. _methods_by_ftype (sig, #= lim=# - 1 ,
62
+ #= world=# typemax (UInt), #= ambig=# false ,
63
+ min_world, max_world, has_ambig)
64
+ end
65
+ # XXX : using world=-1 is wrong, but the current world isn't exposed to this generator
66
+
67
+ # check the validity of the method matches
68
+ method_error = :(throw (MethodError (ft, tt)))
69
+ mthds === nothing && return method_error
17
70
Base. isdispatchtuple (tt) || return (:(error (" $tt is not a dispatch tuple" )))
18
- length (mthds) == 1 || return (:(throw (MethodError (job. source. f,job. source. tt))))
71
+ length (mthds) == 1 || return method_error
72
+
73
+ # look up the method and code instance
19
74
mtypes, msp, m = mthds[1 ]
20
75
mi = ccall (:jl_specializations_get_linfo , Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
21
76
ci = retrieve_code_info (mi):: CodeInfo
22
77
23
- # generate a unique id to represent this specialization
24
- # TODO : just use the lower world age bound in which this code info is valid.
25
- # (the method instance doesn't change when called functions are changed).
26
- # but how to get that? the ci here always has min/max world 1/-1.
27
- # XXX : don't use `objectid(ci)` here, apparently it can alias (or the CI doesn't change?)
28
- id = (specialization_counter[] += 1 )
78
+ # XXX : we don't know the world age that this generator was requested to run in, so use
79
+ # the current world (we cannot use the mi's world because that doesn't update when
80
+ # called functions are changed). this isn't correct, but should be close.
81
+ world = Base. get_world_counter ()
29
82
30
83
# prepare a new code info
31
84
new_ci = copy (ci)
@@ -34,22 +87,20 @@ const specialization_counter = Ref{UInt}(0)
34
87
resize! (new_ci. linetable, 1 ) # see note below
35
88
empty! (new_ci. ssaflags)
36
89
new_ci. ssavaluetypes = 0
90
+ new_ci. min_world = min_world[]
91
+ new_ci. max_world = max_world[]
37
92
new_ci. edges = MethodInstance[mi]
38
93
# XXX : setting this edge does not give us proper method invalidation, see
39
94
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
40
95
# invoking `code_llvm` also does the necessary codegen, as does calling the
41
96
# underlying C methods -- which GPUCompiler does, so everything Just Works.
42
97
43
98
# prepare the slots
44
- new_ci. slotnames = Symbol[Symbol (" #self#" ), :cache , :job , :compiler , :linker ]
45
- new_ci. slotflags = UInt8[0x00 for i = 1 : 5 ]
46
- cache = SlotNumber (2 )
47
- job = SlotNumber (3 )
48
- compiler = SlotNumber (4 )
49
- linker = SlotNumber (5 )
50
-
51
- # call the compiler
52
- push! (new_ci. code, ReturnNode (id))
99
+ new_ci. slotnames = Symbol[Symbol (" #self#" ), :ft , :tt ]
100
+ new_ci. slotflags = UInt8[0x00 for i = 1 : 3 ]
101
+
102
+ # return the world
103
+ push! (new_ci. code, ReturnNode (world))
53
104
push! (new_ci. ssaflags, 0x00 ) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
54
105
push! (new_ci. codelocs, 1 ) # see note below
55
106
new_ci. ssavaluetypes += 1
@@ -62,17 +113,48 @@ const specialization_counter = Ref{UInt}(0)
62
113
return new_ci
63
114
end
64
115
116
+ @eval function get_world (ft, tt)
117
+ $ (Expr (:meta , :generated_only ))
118
+ $ (Expr (:meta ,
119
+ :generated ,
120
+ Expr (:new ,
121
+ Core. GeneratedFunctionStub,
122
+ :get_world_generator ,
123
+ Any[:get_world , :ft , :tt ],
124
+ Any[],
125
+ @__LINE__ ,
126
+ QuoteNode (Symbol (@__FILE__ )),
127
+ true )))
128
+ end
129
+
65
130
const cache_lock = ReentrantLock ()
131
+
132
+ """
133
+ cached_compilation(cache::Dict, job::CompilerJob, compiler, linker)
134
+
135
+ Compile `job` using `compiler` and `linker`, and store the result in `cache`.
136
+
137
+ The `cache` argument should be a dictionary that can be indexed using a `UInt` and store
138
+ whatever the `linker` function returns. The `compiler` function should take a `CompilerJob`
139
+ and return data that can be cached across sessions (e.g., LLVM IR). This data is then
140
+ forwarded, along with the `CompilerJob`, to the `linker` function which is allowed to create
141
+ session-dependent objects (e.g., a `CuModule`).
142
+ """
66
143
function cached_compilation (cache:: AbstractDict ,
67
144
@nospecialize (job:: CompilerJob ),
68
145
compiler:: Function , linker:: Function )
69
- # XXX : CompilerJob contains a world age, so can't be respecialized.
70
- # have specialization_id take a f/tt and return a world to construct a CompilerJob?
71
- key = hash (job, specialization_id (job))
72
- force_compilation = compile_hook[] != = nothing
146
+ # NOTE: it is OK to index the compilation cache directly with the compilation job, i.e.,
147
+ # using a world age instead of intersecting world age ranges, because we expect
148
+ # that the world age is acquired through calling `get_world` and thus will only
149
+ # ever change when the kernel function is redefined.
150
+ #
151
+ # if we ever want to be able to index the cache using a compilation job that
152
+ # contains a more recent world age, yet still return an older cached object that
153
+ # would still be valid, we'd need the cache to store world ranges instead and
154
+ # use an invalidation callback to add upper bounds to entries.
155
+ key = hash (job)
73
156
74
- # XXX : by taking the hash, we index the compilation cache directly with the world age.
75
- # that's wrong; we should perform an intersection with the entry its bounds.
157
+ force_compilation = compile_hook[] != = nothing
76
158
77
159
# NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
78
160
lock (cache_lock)
0 commit comments