Skip to content

Commit 1e6e656

Browse files
authored
Refactor compiler method lookup interface (#36743)
The primary motivation here is to clean up the notion of "looking up a method in the method table" into a single object that can be passed around. At the moment, it's all a bit murky, with fairly large state objects being passed around everywhere and implicit accesses to the global environment. In my AD use case, I need to be a bit careful to make sure the various inference and optimization steps are looking things up in the correct tables/caches, so being very explicit about where things need to be looked up is quite helpful. In particular, I would like to clean up the optimizer, to not require the big `OptimizationState` which is currently a bit of a mix of things that go into IRCode and information needed for method lookup/edge tracking. That isn't part of this PR, but will build on top of it. More generally, with a bunch of the recent compiler work, I've been trying to define more crisp boundaries between the various components of the system, giving them clearer interfaces, and at least a little bit of documentation. The compiler is a very powerful bit of technology, but I think people having been avoiding it, because the code looks a bit scary. I'm hoping some of these cleanups will make it easier for people to understand what's going on. Here in particular, I'm using `findall(sig, table)` as the predicate for method lookup. The idea being that people familiar with the `findall(predicate, collection)` idiom from regular julia will have a good intuitive understanding of what's happening (a collection is searched for a predicate), an array of matches is returned, etc. Of course, it's not a perfect fit, but I think these kinds of mental aids can be helpful in making it easier for people to read compiler code (similar to how #35831 used `getindex` as the verb for cache lookup). While I was at it, I also cleaned up the use of out-parameters which leaked through too much of the underlying C API and replaced them by a proper struct of results.
1 parent 3699192 commit 1e6e656

File tree

14 files changed

+213
-129
lines changed

14 files changed

+213
-129
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,13 @@ const _REF_NAME = Ref.body.name
1616
call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
1717
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc])
1818

19-
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt)
20-
box = Core.Box(atype)
21-
return get!(cache, atype) do
22-
_min_val = UInt[typemin(UInt)]
23-
_max_val = UInt[typemax(UInt)]
24-
_ambig = Int32[0]
25-
ms = _methods_by_ftype(box.contents, max_methods, world, false, _min_val, _max_val, _ambig)
26-
return ms, _min_val[1], _max_val[1], _ambig[1] != 0
27-
end
28-
end
29-
30-
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt, min_valid::Vector{UInt}, max_valid::Vector{UInt})
31-
ms, minvalid, maxvalid, ambig = matching_methods(atype, cache, max_methods, world)
32-
min_valid[1] = max(min_valid[1], minvalid)
33-
max_valid[1] = min(max_valid[1], maxvalid)
34-
return ms, ambig
35-
end
3619

3720
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
3821
max_methods::Int = InferenceParams(interp).MAX_METHODS)
3922
if sv.currpc in sv.throw_blocks
4023
return CallMeta(Any, false)
4124
end
42-
min_valid = UInt[typemin(UInt)]
43-
max_valid = UInt[typemax(UInt)]
25+
valid_worlds = WorldRange()
4426
atype_params = unwrap_unionall(atype).parameters
4527
splitunions = 1 < countunionsplit(atype_params) <= InferenceParams(interp).MAX_UNION_SPLITTING
4628
mts = Core.MethodTable[]
@@ -56,15 +38,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
5638
return CallMeta(Any, false)
5739
end
5840
mt = mt::Core.MethodTable
59-
xapplicable, ambig = matching_methods(sig_n, sv.matching_methods_cache, max_methods,
60-
get_world_counter(interp), min_valid, max_valid)
61-
if xapplicable === false
41+
matches = findall(sig_n, method_table(interp); limit=max_methods)
42+
if matches === missing
6243
add_remark!(interp, sv, "For one of the union split cases, too many methods matched")
6344
return CallMeta(Any, false)
6445
end
65-
push!(infos, MethodMatchInfo(xapplicable, ambig))
66-
append!(applicable, xapplicable)
67-
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, xapplicable)
46+
push!(infos, MethodMatchInfo(matches))
47+
append!(applicable, matches)
48+
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
49+
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
6850
found = false
6951
for (i, mt′) in enumerate(mts)
7052
if mt′ === mt
@@ -86,19 +68,20 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
8668
return CallMeta(Any, false)
8769
end
8870
mt = mt::Core.MethodTable
89-
applicable, ambig = matching_methods(atype, sv.matching_methods_cache, max_methods,
90-
get_world_counter(interp), min_valid, max_valid)
91-
if applicable === false
71+
matches = findall(atype, method_table(interp, sv); limit=max_methods)
72+
if matches === missing
9273
# this means too many methods matched
9374
# (assume this will always be true, so we don't compute / update valid age in this case)
9475
add_remark!(interp, sv, "Too many methods matched")
9576
return CallMeta(Any, false)
9677
end
9778
push!(mts, mt)
98-
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, applicable))
99-
info = MethodMatchInfo(applicable, ambig)
79+
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches))
80+
info = MethodMatchInfo(matches)
81+
applicable = matches.matches
82+
valid_worlds = matches.valid_worlds
10083
end
101-
update_valid_age!(min_valid[1], max_valid[1], sv)
84+
update_valid_age!(sv, valid_worlds)
10285
applicable = applicable::Array{Any,1}
10386
napplicable = length(applicable)
10487
rettype = Bottom
@@ -1460,12 +1443,7 @@ function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
14601443
typeinf_local(interp, caller)
14611444
no_active_ips_in_callers = false
14621445
end
1463-
if caller.min_valid < frame.min_valid
1464-
caller.min_valid = frame.min_valid
1465-
end
1466-
if caller.max_valid > frame.max_valid
1467-
caller.max_valid = frame.max_valid
1468-
end
1446+
caller.valid_worlds = intersect(caller.valid_worlds, frame.valid_worlds)
14691447
end
14701448
end
14711449
return true

base/compiler/cicache.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,23 @@ end
1414

1515
const GLOBAL_CI_CACHE = InternalCodeCache()
1616

17+
struct WorldRange
18+
min_world::UInt
19+
max_world::UInt
20+
end
21+
WorldRange() = WorldRange(typemin(UInt), typemax(UInt))
22+
WorldRange(w::UInt) = WorldRange(w, w)
23+
WorldRange(r::UnitRange) = WorldRange(first(r), last(r))
24+
first(wr::WorldRange) = wr.min_world
25+
last(wr::WorldRange) = wr.max_world
26+
in(world::UInt, wr::WorldRange) = wr.min_world <= world <= wr.max_world
27+
28+
function intersect(a::WorldRange, b::WorldRange)
29+
ret = WorldRange(max(a.min_world, b.min_world), min(a.max_world, b.max_world))
30+
@assert ret.min_world <= ret.max_world
31+
return ret
32+
end
33+
1734
"""
1835
struct WorldView
1936
@@ -22,20 +39,19 @@ range of world ages, rather than defaulting to the current active world age.
2239
"""
2340
struct WorldView{Cache}
2441
cache::Cache
25-
min_world::UInt
26-
max_world::UInt
42+
worlds::WorldRange
43+
WorldView(cache::Cache, range::WorldRange) where Cache = new{Cache}(cache, range)
2744
end
28-
WorldView(cache, r::UnitRange) = WorldView(cache, first(r), last(r))
29-
WorldView(cache, world::UInt) = WorldView(cache, world, world)
30-
WorldView(wvc::WorldView, min_world::UInt, max_world::UInt) =
31-
WorldView(wvc.cache, min_world, max_world)
45+
WorldView(cache, args...) = WorldView(cache, WorldRange(args...))
46+
WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr)
47+
WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...)
3248

3349
function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
34-
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} !== nothing
50+
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance} !== nothing
3551
end
3652

3753
function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
38-
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance}
54+
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance}
3955
if r === nothing
4056
return default
4157
end

base/compiler/compiler.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,11 @@ include("compiler/types.jl")
107107
include("compiler/utilities.jl")
108108
include("compiler/validation.jl")
109109

110+
include("compiler/cicache.jl")
111+
include("compiler/methodtable.jl")
112+
110113
include("compiler/inferenceresult.jl")
111114
include("compiler/inferencestate.jl")
112-
include("compiler/cicache.jl")
113115

114116
include("compiler/typeutils.jl")
115117
include("compiler/typelimits.jl")

base/compiler/inferencestate.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ mutable struct InferenceState
1414
# info on the state of inference and the linfo
1515
src::CodeInfo
1616
world::UInt
17-
min_valid::UInt
18-
max_valid::UInt
17+
valid_worlds::WorldRange
1918
nargs::Int
2019
stmt_types::Vector{Any}
2120
stmt_edges::Vector{Any}
@@ -44,9 +43,10 @@ mutable struct InferenceState
4443
inferred::Bool
4544
dont_work_on_me::Bool
4645

47-
# cached results of calling `_methods_by_ftype`, including `min_valid` and
48-
# `max_valid`, to be used in inlining
49-
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}
46+
# The place to look up methods while working on this function.
47+
# In particular, we cache method lookup results for the same function to
48+
# fast path repeated queries.
49+
method_table::CachedMethodTable{InternalMethodTable}
5050

5151
# The interpreter that created this inference state. Not looked at by
5252
# NativeInterpreter. But other interpreters may use this to detect cycles
@@ -100,13 +100,12 @@ mutable struct InferenceState
100100
inmodule = linfo.def::Module
101101
end
102102

103-
min_valid = src.min_world
104-
max_valid = src.max_world == typemax(UInt) ?
105-
get_world_counter() : src.max_world
103+
valid_worlds = WorldRange(src.min_world,
104+
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
106105
frame = new(
107106
InferenceParams(interp), result, linfo,
108107
sp, slottypes, inmodule, 0,
109-
src, get_world_counter(interp), min_valid, max_valid,
108+
src, get_world_counter(interp), valid_worlds,
110109
nargs, s_types, s_edges, stmt_info,
111110
Union{}, W, 1, n,
112111
cur_hand, handler_at, n_handlers,
@@ -115,14 +114,16 @@ mutable struct InferenceState
115114
Vector{InferenceState}(), # callers_in_cycle
116115
#=parent=#nothing,
117116
cached, false, false, false,
118-
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(),
117+
CachedMethodTable(method_table(interp)),
119118
interp)
120119
result.result = frame
121120
cached && push!(get_inference_cache(interp), result)
122121
return frame
123122
end
124123
end
125124

125+
method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table
126+
126127
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
127128
# prepare an InferenceState object for inferring lambda
128129
src = retrieve_code_info(result.linfo)
@@ -202,14 +203,13 @@ end
202203
_topmod(sv::InferenceState) = _topmod(sv.mod)
203204

204205
# work towards converging the valid age range for sv
205-
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState)
206-
sv.min_valid = max(sv.min_valid, min_valid)
207-
sv.max_valid = min(sv.max_valid, max_valid)
208-
@assert(sv.min_valid <= sv.world <= sv.max_valid, "invalid age range update")
206+
function update_valid_age!(sv::InferenceState, worlds::WorldRange)
207+
sv.valid_worlds = intersect(worlds, sv.valid_worlds)
208+
@assert(sv.world in sv.valid_worlds, "invalid age range update")
209209
nothing
210210
end
211211

212-
update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(edge.min_valid, edge.max_valid, sv)
212+
update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)
213213

214214
function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
215215
old = frame.src.ssavaluetypes[ssa_id]

base/compiler/methodtable.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
abstract type MethodTableView; end
2+
3+
struct MethodLookupResult
4+
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
5+
# and work with Vector{Any} on the C side.
6+
matches::Vector{Any}
7+
valid_worlds::WorldRange
8+
ambig::Bool
9+
end
10+
length(result::MethodLookupResult) = length(result.matches)
11+
function iterate(result::MethodLookupResult, args...)
12+
r = iterate(result.matches, args...)
13+
r === nothing && return nothing
14+
match, state = r
15+
return (match::MethodMatch, state)
16+
end
17+
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
18+
19+
"""
20+
struct InternalMethodTable <: MethodTableView
21+
22+
A struct representing the state of the internal method table at a
23+
particular world age.
24+
"""
25+
struct InternalMethodTable <: MethodTableView
26+
world::UInt
27+
end
28+
29+
"""
30+
struct CachedMethodTable <: MethodTableView
31+
32+
Overlays another method table view with an additional local fast path cache that
33+
can respond to repeated, identical queries faster than the original method table.
34+
"""
35+
struct CachedMethodTable{T} <: MethodTableView
36+
cache::IdDict{Any, Union{Missing, MethodLookupResult}}
37+
table::T
38+
end
39+
CachedMethodTable(table::T) where T =
40+
CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodLookupResult}}(),
41+
table)
42+
43+
"""
44+
findall(sig::Type{<:Tuple}, view::MethodTableView; limit=typemax(Int))
45+
46+
Find all methods in the given method table `view` that are applicable to the
47+
given signature `sig`. If no applicable methods are found, an empty result is
48+
returned. If the number of applicable methods exeeded the specified limit,
49+
`missing` is returned.
50+
"""
51+
function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable; limit::Int=typemax(Int))
52+
_min_val = RefValue{UInt}(typemin(UInt))
53+
_max_val = RefValue{UInt}(typemax(UInt))
54+
_ambig = RefValue{Int32}(0)
55+
ms = _methods_by_ftype(sig, limit, table.world, false, _min_val, _max_val, _ambig)
56+
if ms === false
57+
return missing
58+
end
59+
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
60+
end
61+
62+
function findall(@nospecialize(sig::Type{<:Tuple}), table::CachedMethodTable; limit::Int=typemax(Int))
63+
box = Core.Box(sig)
64+
return get!(table.cache, sig) do
65+
findall(box.contents, table.table; limit=limit)
66+
end
67+
end
68+
69+
"""
70+
findsup(sig::Type{<:Tuple}, view::MethodTableView)::Union{Tuple{MethodMatch, WorldRange}, Nothing}
71+
72+
Find the (unique) method `m` such that `sig <: m.sig`, while being more
73+
specific than any other method with the same property. In other words, find
74+
the method which is the least upper bound (supremum) under the specificity/subtype
75+
relation of the queried `signature`. If `sig` is concrete, this is equivalent to
76+
asking for the method that will be called given arguments whose types match the
77+
given signature. This query is also used to implement `invoke`.
78+
79+
Such a method `m` need not exist. It is possible that no method is an
80+
upper bound of `sig`, or it is possible that among the upper bounds, there
81+
is no least element. In both cases `nothing` is returned.
82+
"""
83+
function findsup(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable)
84+
min_valid = RefValue{UInt}(typemin(UInt))
85+
max_valid = RefValue{UInt}(typemax(UInt))
86+
result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
87+
sig, table.world, min_valid, max_valid)::Union{Method, Nothing}
88+
result === nothing && return nothing
89+
(result, WorldRange(min_valid[], max_valid[]))
90+
end
91+
92+
# This query is not cached
93+
findsup(sig::Type{<:Tuple}, table::CachedMethodTable) = findsup(sig, table.table)

base/compiler/optimize.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,10 @@ mutable struct OptimizationState
1313
mod::Module
1414
nargs::Int
1515
world::UInt
16-
min_valid::UInt
17-
max_valid::UInt
16+
valid_worlds::WorldRange
1817
sptypes::Vector{Any} # static parameters
1918
slottypes::Vector{Any}
2019
const_api::Bool
21-
# cached results of calling `_methods_by_ftype` from inference, including
22-
# `min_valid` and `max_valid`
23-
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}
2420
# TODO: This will be eliminated once optimization no longer needs to do method lookups
2521
interp::AbstractInterpreter
2622
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
@@ -33,9 +29,9 @@ mutable struct OptimizationState
3329
return new(params, frame.linfo,
3430
s_edges::Vector{Any},
3531
src, frame.stmt_info, frame.mod, frame.nargs,
36-
frame.world, frame.min_valid, frame.max_valid,
32+
frame.world, frame.valid_worlds,
3733
frame.sptypes, frame.slottypes, false,
38-
frame.matching_methods_cache, interp)
34+
interp)
3935
end
4036
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
4137
# prepare src for running optimization passes
@@ -64,9 +60,9 @@ mutable struct OptimizationState
6460
return new(params, linfo,
6561
s_edges::Vector{Any},
6662
src, stmt_info, inmodule, nargs,
67-
get_world_counter(), UInt(1), get_world_counter(),
63+
get_world_counter(), WorldRange(UInt(1), get_world_counter()),
6864
sptypes_from_meth_instance(linfo), slottypes, false,
69-
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(), interp)
65+
interp)
7066
end
7167
end
7268

@@ -110,11 +106,9 @@ const TOP_TUPLE = GlobalRef(Core, :tuple)
110106

111107
_topmod(sv::OptimizationState) = _topmod(sv.mod)
112108

113-
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState)
114-
sv.min_valid = max(sv.min_valid, min_valid)
115-
sv.max_valid = min(sv.max_valid, max_valid)
116-
@assert(sv.min_valid <= sv.world <= sv.max_valid,
117-
"invalid age range update")
109+
function update_valid_age!(sv::OptimizationState, valid_worlds::WorldRange)
110+
sv.valid_worlds = intersect(sv.valid_worlds, valid_worlds)
111+
@assert(sv.world in sv.valid_worlds, "invalid age range update")
118112
nothing
119113
end
120114

@@ -126,7 +120,7 @@ function add_backedge!(li::MethodInstance, caller::OptimizationState)
126120
end
127121

128122
function add_backedge!(li::CodeInstance, caller::OptimizationState)
129-
update_valid_age!(min_world(li), max_world(li), caller)
123+
update_valid_age!(caller, WorldRange(min_world(li), max_world(li)))
130124
add_backedge!(li.def, caller)
131125
nothing
132126
end

0 commit comments

Comments
 (0)