Skip to content

Commit 9738e4b

Browse files
Kenomaleadtvtjnash
authored
Refactor cache logic for easy replacement (#35831)
* Refactor cache logic for easy replacement This is the next step in the line of work started by #33955, though a lot of enabling work towards this was previously done by Jameson in his codegen-norecursion branch. The basic thrust here is to allow external packages to manage their own cache of compiled code that may have been generated using entirely difference inference or compiler options. The GPU compilers are one such example, but there are several others, including generating code using offload compilers, such as XLA or compilers for secure computation. A lot of this is just moving code arround to make it clear exactly which parts of the code are accessing the internal code cache (which is now its own type to make it obvious when it's being accessed), as well as providing clear extension points for custom cache implementations. The second part is to refactor CodeInstance construction to separate construction and insertion into the internal cache (so it can be inserted into an external cache instead if desired). The last part of the change is to give cgparams another hook that lets the caller replace the cache lookup to be used by codegen. * Update base/compiler/cicache.jl Co-authored-by: Tim Besard <tim.besard@gmail.com> * Apply suggestions from code review Co-authored-by: Jameson Nash <vtjnash@gmail.com> * Rename always_cache_tree -> !allow_discard_tree Co-authored-by: Tim Besard <tim.besard@gmail.com> Co-authored-by: Jameson Nash <vtjnash@gmail.com>
1 parent d765e59 commit 9738e4b

File tree

15 files changed

+227
-100
lines changed

15 files changed

+227
-100
lines changed

base/boot.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,11 @@ eval(Core, :(UpsilonNode(val) = $(Expr(:new, :UpsilonNode, :val))))
381381
eval(Core, :(UpsilonNode() = $(Expr(:new, :UpsilonNode))))
382382
eval(Core, :(LineInfoNode(@nospecialize(method), file::Symbol, line::Int, inlined_at::Int) =
383383
$(Expr(:new, :LineInfoNode, :method, :file, :line, :inlined_at))))
384+
eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
385+
@nospecialize(inferred), const_flags::Int32,
386+
min_world::UInt, max_world::UInt) =
387+
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
388+
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
384389

385390
Module(name::Symbol=:anonymous, std_imports::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool), name, std_imports)
386391

base/compiler/abstractinterpretation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
241241
mi = mi::MethodInstance
242242
# decide if it's likely to be worthwhile
243243
if !force_inference
244-
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
244+
code = get(code_cache(interp), mi, nothing)
245245
declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source)
246246
cache_inlineable = declared_inline
247247
if isdefined(code, :inferred) && !cache_inlineable

base/compiler/cicache.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
struct InternalCodeCache
3+
4+
Internally, each `MethodInstance` keep a unique global cache of code instances
5+
that have been created for the given method instance, stratified by world age
6+
ranges. This struct abstracts over access to this cache.
7+
"""
8+
struct InternalCodeCache
9+
end
10+
11+
function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
12+
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
13+
end
14+
15+
const GLOBAL_CI_CACHE = InternalCodeCache()
16+
17+
"""
18+
struct WorldView
19+
20+
Takes a given cache and provides access to the cache contents for the given
21+
range of world ages, rather than defaulting to the current active world age.
22+
"""
23+
struct WorldView{Cache}
24+
cache::Cache
25+
min_world::UInt
26+
max_world::UInt
27+
end
28+
WorldView(cache, r::UnitRange) = WorldView(cache, first(r), last(r))
29+
WorldView(cache, world::UInt) = WorldView(cache, world, world)
30+
WorldView(wvc::WorldView, min_world::UInt, max_world::UInt) =
31+
WorldView(wvc.cache, min_world, max_world)
32+
33+
function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
34+
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} !== nothing
35+
end
36+
37+
function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
38+
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance}
39+
if r === nothing
40+
return default
41+
end
42+
return r::CodeInstance
43+
end
44+
45+
function getindex(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
46+
r = get(wvc, mi, nothing)
47+
r === nothing && throw(KeyError(mi))
48+
return r::CodeInstance
49+
end
50+
51+
setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance) =
52+
setindex!(wvc.cache, ci, mi)

base/compiler/compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ include("compiler/validation.jl")
100100

101101
include("compiler/inferenceresult.jl")
102102
include("compiler/inferencestate.jl")
103+
include("compiler/cicache.jl")
103104

104105
include("compiler/typeutils.jl")
105106
include("compiler/typelimits.jl")

base/compiler/ssair/inlining.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ function iterate(split::UnionSplitSignature, state::Vector{Int}...)
798798
return (sig, state)
799799
end
800800

801-
function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Any}, sv::OptimizationState)
801+
function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Any})
802802
if isa(case, ConstantCase)
803803
ir[SSAValue(idx)] = case.val
804804
elseif isa(case, MethodInstance)
@@ -949,7 +949,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok
949949
methsp = methsp::SimpleVector
950950
result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data,
951951
calltype)
952-
handle_single_case!(ir, stmt, idx, result, true, todo, sv)
952+
handle_single_case!(ir, stmt, idx, result, true, todo)
953953
update_valid_age!(invoke_data.min_valid, invoke_data.max_valid, sv)
954954
return nothing
955955
end
@@ -1117,7 +1117,7 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
11171117
# be able to do the inlining now (for constant cases), or push it directly
11181118
# onto the todo list
11191119
if fully_covered && length(cases) == 1
1120-
handle_single_case!(ir, stmt, idx, cases[1][2], false, todo, sv)
1120+
handle_single_case!(ir, stmt, idx, cases[1][2], false, todo)
11211121
continue
11221122
end
11231123
length(cases) == 0 && continue
@@ -1332,7 +1332,7 @@ function find_inferred(mi::MethodInstance, @nospecialize(atypes), sv::Optimizati
13321332
end
13331333
end
13341334

1335-
linfo = inf_for_methodinstance(sv.interp, mi, sv.world)
1335+
linfo = get(WorldView(code_cache(sv.interp), sv.world), mi, nothing)
13361336
if linfo isa CodeInstance
13371337
if invoke_api(linfo) == 2
13381338
# in this case function can be inlined to a constant

base/compiler/typeinfer.jl

Lines changed: 71 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cached::Bool)
55
frame = InferenceState(result, cached, interp)
66
frame === nothing && return false
7-
cached && (result.linfo.inInference = true)
7+
cached && lock_mi_inference(interp, result.linfo)
88
return typeinf(interp, frame)
99
end
1010

@@ -64,7 +64,7 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
6464
caller.src.min_world = min_valid
6565
caller.src.max_world = max_valid
6666
if cached
67-
cache_result(interp, caller.result, min_valid, max_valid)
67+
cache_result!(interp, caller.result, min_valid, max_valid)
6868
end
6969
if max_valid == typemax(UInt)
7070
# if we aren't cached, we don't need this edge
@@ -79,60 +79,78 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
7979
return true
8080
end
8181

82-
# inference completed on `me`
83-
# update the MethodInstance and notify the edges
84-
function cache_result(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt)
85-
def = result.linfo.def
86-
toplevel = !isa(result.linfo.def, Method)
8782

88-
# check if the existing linfo metadata is also sufficient to describe the current inference result
89-
# to decide if it is worth caching this
90-
already_inferred = !result.linfo.inInference
91-
if inf_for_methodinstance(interp, result.linfo, min_valid, max_valid) isa CodeInstance
92-
already_inferred = true
93-
end
94-
95-
# TODO: also don't store inferred code if we've previously decided to interpret this function
96-
if !already_inferred
97-
inferred_result = result.src
98-
if inferred_result isa Const
99-
# use constant calling convention
100-
rettype_const = (result.src::Const).val
101-
const_flags = 0x3
83+
function CodeInstance(result::InferenceResult, min_valid::UInt, max_valid::UInt,
84+
may_compress=true, allow_discard_tree=true)
85+
inferred_result = result.src
86+
local const_flags::Int32
87+
if inferred_result isa Const
88+
# use constant calling convention
89+
rettype_const = (result.src::Const).val
90+
const_flags = 0x3
91+
else
92+
if isa(result.result, Const)
93+
rettype_const = (result.result::Const).val
94+
const_flags = 0x2
95+
elseif isconstType(result.result)
96+
rettype_const = result.result.parameters[1]
97+
const_flags = 0x2
10298
else
103-
if isa(result.result, Const)
104-
rettype_const = (result.result::Const).val
105-
const_flags = 0x2
106-
elseif isconstType(result.result)
107-
rettype_const = result.result.parameters[1]
108-
const_flags = 0x2
109-
else
110-
rettype_const = nothing
111-
const_flags = 0x00
112-
end
113-
if !toplevel && inferred_result isa CodeInfo
114-
cache_the_tree = result.src.inferred &&
99+
rettype_const = nothing
100+
const_flags = 0x00
101+
end
102+
if inferred_result isa CodeInfo
103+
def = result.linfo.def
104+
toplevel = !isa(def, Method)
105+
if !toplevel
106+
cache_the_tree = !allow_discard_tree || (result.src.inferred &&
115107
(result.src.inlineable ||
116-
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), result.linfo.specTypes, def) != 0)
108+
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), result.linfo.specTypes, def) != 0))
117109
if cache_the_tree
118-
# compress code for non-toplevel thunks
119-
nslots = length(inferred_result.slotflags)
120-
resize!(inferred_result.slottypes, nslots)
121-
resize!(inferred_result.slotnames, nslots)
122-
inferred_result = ccall(:jl_compress_ir, Any, (Any, Any), def, inferred_result)
110+
if may_compress
111+
nslots = length(inferred_result.slotflags)
112+
resize!(inferred_result.slottypes, nslots)
113+
resize!(inferred_result.slotnames, nslots)
114+
inferred_result = ccall(:jl_compress_ir, Any, (Any, Any), def, inferred_result)
115+
end
123116
else
124117
inferred_result = nothing
125118
end
126119
end
127120
end
128-
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
129-
inferred_result = nothing
130-
end
131-
ccall(:jl_set_method_inferred, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
132-
result.linfo, widenconst(result.result), rettype_const, inferred_result,
133-
const_flags, min_valid, max_valid)
134121
end
135-
result.linfo.inInference = false
122+
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
123+
inferred_result = nothing
124+
end
125+
return CodeInstance(result.linfo,
126+
widenconst(result.result), rettype_const, inferred_result,
127+
const_flags, min_valid, max_valid)
128+
end
129+
130+
# For the NativeInterpreter, we don't need to do an actual cache query to know
131+
# if something was already inferred. If we reach this point, but the inference
132+
# flag has been turned off, then it's in the cache. This is purely a performance
133+
# optimization.
134+
already_inferred_quick_test(interp::NativeInterpreter, mi::MethodInstance) =
135+
!mi.inInference
136+
already_inferred_quick_test(interp::AbstractInterpreter, mi::MethodInstance) =
137+
false
138+
139+
# inference completed on `me`
140+
# update the MethodInstance
141+
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt)
142+
# check if the existing linfo metadata is also sufficient to describe the current inference result
143+
# to decide if it is worth caching this
144+
already_inferred = already_inferred_quick_test(interp, result.linfo)
145+
if !already_inferred && haskey(WorldView(code_cache(interp), min_valid, max_valid), result.linfo)
146+
already_inferred = true
147+
end
148+
149+
# TODO: also don't store inferred code if we've previously decided to interpret this function
150+
if !already_inferred
151+
code_cache(interp)[result.linfo] = CodeInstance(result, min_valid, max_valid)
152+
end
153+
unlock_mi_inference(interp, result.linfo)
136154
nothing
137155
end
138156

@@ -142,7 +160,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
142160
# a top parent will be cached still, but not this intermediate work
143161
# we can throw everything else away now
144162
me.cached = false
145-
me.linfo.inInference = false
163+
unlock_mi_inference(interp, me.linfo)
146164
me.src.inlineable = false
147165
else
148166
# annotate fulltree with type information
@@ -452,7 +470,7 @@ end
452470
# compute (and cache) an inferred AST and return the current best estimate of the result type
453471
function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize(atypes), sparams::SimpleVector, caller::InferenceState)
454472
mi = specialize_method(method, atypes, sparams)::MethodInstance
455-
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
473+
code = get(code_cache(interp), mi, nothing)
456474
if code isa CodeInstance # return existing rettype if the code is already inferred
457475
update_valid_age!(min_world(code), max_world(code), caller)
458476
if isdefined(code, :rettype_const)
@@ -470,12 +488,12 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
470488
end
471489
if frame === false
472490
# completely new
473-
mi.inInference = true
491+
lock_mi_inference(interp, mi)
474492
result = InferenceResult(mi)
475493
frame = InferenceState(result, #=cached=#true, interp) # always use the cache for edge targets
476494
if frame === nothing
477495
# can't get the source for this, so we know nothing
478-
mi.inInference = false
496+
unlock_mi_inference(interp, mi)
479497
return Any, nothing
480498
end
481499
if caller.cached || caller.limited # don't involve uncached functions in cycle resolution
@@ -524,7 +542,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
524542
method = mi.def::Method
525543
for i = 1:2 # test-and-lock-and-test
526544
i == 2 && ccall(:jl_typeinf_begin, Cvoid, ())
527-
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
545+
code = get(code_cache(interp), mi, nothing)
528546
if code isa CodeInstance
529547
# see if this code already exists in the cache
530548
inf = code.inferred
@@ -565,7 +583,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
565583
end
566584
end
567585
end
568-
mi.inInference = true
586+
lock_mi_inference(interp, mi)
569587
frame = InferenceState(InferenceResult(mi), #=cached=#true, interp)
570588
frame === nothing && return nothing
571589
typeinf(interp, frame)
@@ -582,7 +600,7 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
582600
mi = specialize_method(method, atypes, sparams)::MethodInstance
583601
for i = 1:2 # test-and-lock-and-test
584602
i == 2 && ccall(:jl_typeinf_begin, Cvoid, ())
585-
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
603+
code = get(code_cache(interp), mi, nothing)
586604
if code isa CodeInstance
587605
# see if this rettype already exists in the cache
588606
i == 2 && ccall(:jl_typeinf_end, Cvoid, ())

base/compiler/types.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,17 @@ InferenceParams(ni::NativeInterpreter) = ni.inf_params
175175
OptimizationParams(ni::NativeInterpreter) = ni.opt_params
176176
get_world_counter(ni::NativeInterpreter) = ni.world
177177
get_inference_cache(ni::NativeInterpreter) = ni.cache
178+
179+
code_cache(ni::NativeInterpreter) = WorldView(GLOBAL_CI_CACHE, ni.world)
180+
181+
"""
182+
lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance)
183+
184+
Hint that `mi` is in inference to help accelerate bootstrapping. This helps limit the amount of wasted work we might do when inference is working on initially inferring itself by letting us detect when inference is already in progress and not running a second copy on it. This creates a data-race, but the entry point into this code from C (jl_type_infer) already includes detection and restriction on recursion, so it is hopefully mostly a benign problem (since it should really only happen during the first phase of bootstrapping that we encounter this flag).
185+
"""
186+
lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance) = (mi.inInference = true; nothing)
187+
188+
"""
189+
See lock_mi_inference
190+
"""
191+
unlock_mi_inference(ni::NativeInterpreter, mi::MethodInstance) = (mi.inInference = false; nothing)

base/compiler/utilities.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,6 @@ function retrieve_code_info(linfo::MethodInstance)
118118
end
119119
end
120120

121-
function inf_for_methodinstance(interp::AbstractInterpreter, mi::MethodInstance, min_world::UInt, max_world::UInt=min_world)
122-
return ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, min_world, max_world)::Union{Nothing, CodeInstance}
123-
end
124-
125-
126121
# get a handle to the unique specialization object representing a particular instantiation of a call
127122
function specialize_method(method::Method, @nospecialize(atypes), sparams::SimpleVector, preexisting::Bool=false)
128123
if preexisting

base/reflection.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -981,17 +981,20 @@ struct CodegenParams
981981
emit_function::Any
982982
emitted_function::Any
983983

984+
lookup::Ptr{Cvoid}
985+
984986
function CodegenParams(; track_allocations::Bool=true, code_coverage::Bool=true,
985987
static_alloc::Bool=true, prefer_specsig::Bool=false,
986988
gnu_pubnames=true, debug_info_kind::Cint = default_debug_info_kind(),
987989
module_setup=nothing, module_activation=nothing, raise_exception=nothing,
988-
emit_function=nothing, emitted_function=nothing)
990+
emit_function=nothing, emitted_function=nothing,
991+
lookup::Ptr{Cvoid}=cglobal(:jl_rettype_inferred))
989992
return new(
990993
Cint(track_allocations), Cint(code_coverage),
991994
Cint(static_alloc), Cint(prefer_specsig),
992995
Cint(gnu_pubnames), debug_info_kind,
993996
module_setup, module_activation, raise_exception,
994-
emit_function, emitted_function)
997+
emit_function, emitted_function, lookup)
995998
end
996999
end
9971000

0 commit comments

Comments
 (0)