Skip to content

Commit c5fcd73

Browse files
staticfloat authored and Keno committed
Add AbstractInterpreter to parameterize compilation pipeline
This allows selective overriding of the compilation pipeline through multiple dispatch, enabling projects like `XLA.jl` to maintain separate inference caches, inference algorithms or heuristic algorithms while inferring and lowering code. In particular, it defines a new type, `AbstractInterpreter`, that represents an abstract interpretation pipeline. This `AbstractInterpreter` has a single defined concrete subtype, `NativeInterpreter`, that represents the native Julia compilation pipeline. The `NativeInterpreter` contains within it all the compiler parameters previously contained within `Params`, split into two pieces: `InferenceParams` and `OptimizationParams`, used within type inference and optimization, respectively. The interpreter object is then threaded throughout most of the type inference pipeline, and allows for straightforward prototyping and replacement of the compiler internals. As a simple example of the kind of workflow this enables, I include here a simple testing script showing how to use this to easily get a list of the number of times a function is inferred during type inference by overriding just two functions within the compiler. First, I will define here some simple methods to make working with inference a bit easier: ```julia using Core.Compiler import Core.Compiler: InferenceParams, OptimizationParams, get_world_counter, get_inference_cache """ @infer_function interp foo(1, 2) [show_steps=true] [show_ir=false] Infer a function call using the given interpreter object, return the inference object. Set keyword arguments to modify verbosity: * Set `show_steps` to `true` to see the `InferenceResult` step by step. * Set `show_ir` to `true` to see the final type-inferred Julia IR. """ macro infer_function(interp, func_call, kwarg_exs...) 
if !isa(func_call, Expr) || func_call.head != :call error("@infer_function requires a function call") end local func = func_call.args[1] local args = func_call.args[2:end] kwargs = [] for ex in kwarg_exs if ex isa Expr && ex.head === :(=) && ex.args[1] isa Symbol push!(kwargs, first(ex.args) => last(ex.args)) else error("Invalid @infer_function kwarg $(ex)") end end return quote infer_function($(esc(interp)), $(esc(func)), typeof.(($(args)...,)); $(esc(kwargs))...) end end function infer_function(interp, f, tt; show_steps::Bool=false, show_ir::Bool=false) # Find all methods that are applicable to these types fms = methods(f, tt) if length(fms) != 1 error("Unable to find single applicable method for $f with types $tt") end # Take the first applicable method method = first(fms) # Build argument tuple method_args = Tuple{typeof(f), tt...} # Grab the appropriate method instance for these types mi = Core.Compiler.specialize_method(method, method_args, Core.svec()) # Construct InferenceResult to hold the result, result = Core.Compiler.InferenceResult(mi) if show_steps @info("Initial result, before inference: ", result) end # Create an InferenceState to begin inference, give it a world that is always newest world = Core.Compiler.get_world_counter() frame = Core.Compiler.InferenceState(result, #=cached=# true, interp) # Run type inference on this frame. Because the interpreter is embedded # within this InferenceResult, we don't need to pass the interpreter in. 
Core.Compiler.typeinf_local(interp, frame) if show_steps @info("Ending result, post-inference: ", result) end if show_ir @info("Inferred source: ", result.result.src) end # Give the result back return result end ``` Next, we define a simple function and pass it through: ```julia function foo(x, y) return x + y * x end native_interpreter = Core.Compiler.NativeInterpreter() inferred = @infer_function native_interpreter foo(1.0, 2.0) show_steps=true show_ir=true ``` This gives a nice output such as the following: ```julia-repl ┌ Info: Initial result, before inference: └ result = foo(::Float64, ::Float64) => Any ┌ Info: Ending result, post-inference: └ result = foo(::Float64, ::Float64) => Float64 ┌ Info: Inferred source: │ result.result.src = │ CodeInfo( │ @ REPL[1]:3 within `foo' │ 1 ─ %1 = (y * x)::Float64 │ │ %2 = (x + %1)::Float64 │ └── return %2 └ ) ``` We can then define a custom `AbstractInterpreter` subtype that will override two specific pieces of the compilation process, both of which manage the runtime inference cache: looking up cached results and storing new ones. 
While it will transparently pass all information through to a bundled `NativeInterpreter`, it has the ability to force cache misses in order to re-infer things so that we can easily see how many methods (and which) would be inferred to compile a certain method: ```julia struct CountingInterpreter <: Compiler.AbstractInterpreter visited_methods::Set{Core.Compiler.MethodInstance} methods_inferred::Ref{UInt64} # Keep around a native interpreter so that we can sub off to "super" functions native_interpreter::Core.Compiler.NativeInterpreter end CountingInterpreter() = CountingInterpreter( Set{Core.Compiler.MethodInstance}(), Ref(UInt64(0)), Core.Compiler.NativeInterpreter(), ) InferenceParams(ci::CountingInterpreter) = InferenceParams(ci.native_interpreter) OptimizationParams(ci::CountingInterpreter) = OptimizationParams(ci.native_interpreter) get_world_counter(ci::CountingInterpreter) = get_world_counter(ci.native_interpreter) get_inference_cache(ci::CountingInterpreter) = get_inference_cache(ci.native_interpreter) function Core.Compiler.inf_for_methodinstance(interp::CountingInterpreter, mi::Core.Compiler.MethodInstance, min_world::UInt, max_world::UInt=min_world) # Hit our own cache; if it exists, pass on to the main runtime if mi in interp.visited_methods return Core.Compiler.inf_for_methodinstance(interp.native_interpreter, mi, min_world, max_world) end # Otherwise, we return `nothing`, forcing a cache miss return nothing end function Core.Compiler.cache_result(interp::CountingInterpreter, result::Core.Compiler.InferenceResult, min_valid::UInt, max_valid::UInt) push!(interp.visited_methods, result.linfo) interp.methods_inferred[] += 1 return Core.Compiler.cache_result(interp.native_interpreter, result, min_valid, max_valid) end function reset!(interp::CountingInterpreter) empty!(interp.visited_methods) interp.methods_inferred[] = 0 return nothing end ``` Running it on our testing function: ```julia counting_interpreter = CountingInterpreter() inferred = 
@infer_function counting_interpreter foo(1.0, 2.0) @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])") inferred = @infer_function counting_interpreter foo(1, 2) show_ir=true @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])") inferred = @infer_function counting_interpreter foo(1.0, 2.0) @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])") reset!(counting_interpreter) @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])") inferred = @infer_function counting_interpreter foo(1.0, 2.0) @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])") ``` Also gives us a nice result: ``` [ Info: Cumulative number of methods inferred: 2 ┌ Info: Inferred source: │ result.result.src = │ CodeInfo( │ @ /Users/sabae/src/julia-compilerhack/AbstractInterpreterTest.jl:81 within `foo' │ 1 ─ %1 = (y * x)::Int64 │ │ %2 = (x + %1)::Int64 │ └── return %2 └ ) [ Info: Cumulative number of methods inferred: 4 [ Info: Cumulative number of methods inferred: 4 [ Info: Cumulative number of methods inferred: 0 [ Info: Cumulative number of methods inferred: 2 ```
1 parent 8f512f3 commit c5fcd73

18 files changed

+392
-269
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 75 additions & 72 deletions
Large diffs are not rendered by default.

base/compiler/bootstrap.jl

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,9 @@
66
# since we won't be able to specialize & infer them at runtime
77

88
let fs = Any[typeinf_ext, typeinf, typeinf_edge, pure_eval_call, run_passes],
9-
world = get_world_counter()
9+
world = get_world_counter(),
10+
interp = NativeInterpreter(world)
11+
1012
for x in T_FFUNC_VAL
1113
push!(fs, x[3])
1214
end
@@ -27,7 +29,7 @@ let fs = Any[typeinf_ext, typeinf, typeinf_edge, pure_eval_call, run_passes],
2729
typ[i] = typ[i].ub
2830
end
2931
end
30-
typeinf_type(m[3], Tuple{typ...}, m[2], Params(world))
32+
typeinf_type(interp, m[3], Tuple{typ...}, m[2])
3133
end
3234
end
3335
end

base/compiler/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,11 +94,11 @@ using .Sort
9494
# compiler #
9595
############
9696

97+
include("compiler/types.jl")
9798
include("compiler/utilities.jl")
9899
include("compiler/validation.jl")
99100

100101
include("compiler/inferenceresult.jl")
101-
include("compiler/params.jl")
102102
include("compiler/inferencestate.jl")
103103

104104
include("compiler/typeutils.jl")

base/compiler/inferenceresult.jl

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,18 +2,6 @@
22

33
const EMPTY_VECTOR = Vector{Any}()
44

5-
mutable struct InferenceResult
6-
linfo::MethodInstance
7-
argtypes::Vector{Any}
8-
overridden_by_const::BitVector
9-
result # ::Type, or InferenceState if WIP
10-
src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
11-
function InferenceResult(linfo::MethodInstance, given_argtypes = nothing)
12-
argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes)
13-
return new(linfo, argtypes, overridden_by_const, Any, nothing)
14-
end
15-
end
16-
175
function is_argtype_match(@nospecialize(given_argtype),
186
@nospecialize(cache_argtype),
197
overridden_by_const::Bool)

base/compiler/inferencestate.jl

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
const LineNum = Int
44

55
mutable struct InferenceState
6-
params::Params # describes how to compute the result
6+
params::InferenceParams
77
result::InferenceResult # remember where to put the result
88
linfo::MethodInstance
99
sptypes::Vector{Any} # types of static parameter
@@ -13,6 +13,7 @@ mutable struct InferenceState
1313

1414
# info on the state of inference and the linfo
1515
src::CodeInfo
16+
world::UInt
1617
min_valid::UInt
1718
max_valid::UInt
1819
nargs::Int
@@ -47,7 +48,7 @@ mutable struct InferenceState
4748

4849
# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
4950
function InferenceState(result::InferenceResult, src::CodeInfo,
50-
cached::Bool, params::Params)
51+
cached::Bool, interp::AbstractInterpreter)
5152
linfo = result.linfo
5253
code = src.code::Array{Any,1}
5354
toplevel = !isa(linfo.def, Method)
@@ -95,9 +96,9 @@ mutable struct InferenceState
9596
max_valid = src.max_world == typemax(UInt) ?
9697
get_world_counter() : src.max_world
9798
frame = new(
98-
params, result, linfo,
99+
InferenceParams(interp), result, linfo,
99100
sp, slottypes, inmodule, 0,
100-
src, min_valid, max_valid,
101+
src, get_world_counter(interp), min_valid, max_valid,
101102
nargs, s_types, s_edges,
102103
Union{}, W, 1, n,
103104
cur_hand, handler_at, n_handlers,
@@ -108,17 +109,17 @@ mutable struct InferenceState
108109
cached, false, false, false,
109110
IdDict{Any, Tuple{Any, UInt, UInt}}())
110111
result.result = frame
111-
cached && push!(params.cache, result)
112+
cached && push!(get_inference_cache(interp), result)
112113
return frame
113114
end
114115
end
115116

116-
function InferenceState(result::InferenceResult, cached::Bool, params::Params)
117+
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
117118
# prepare an InferenceState object for inferring lambda
118119
src = retrieve_code_info(result.linfo)
119120
src === nothing && return nothing
120121
validate_code_in_debug_mode(result.linfo, src, "lowered")
121-
return InferenceState(result, src, cached, params)
122+
return InferenceState(result, src, cached, interp)
122123
end
123124

124125
function sptypes_from_meth_instance(linfo::MethodInstance)
@@ -195,8 +196,7 @@ _topmod(sv::InferenceState) = _topmod(sv.mod)
195196
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState)
196197
sv.min_valid = max(sv.min_valid, min_valid)
197198
sv.max_valid = min(sv.max_valid, max_valid)
198-
@assert(sv.min_valid <= sv.params.world <= sv.max_valid,
199-
"invalid age range update")
199+
@assert(sv.min_valid <= sv.world <= sv.max_valid, "invalid age range update")
200200
nothing
201201
end
202202

base/compiler/optimize.jl

Lines changed: 26 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -5,36 +5,38 @@
55
#####################
66

77
mutable struct OptimizationState
8+
params::OptimizationParams
89
linfo::MethodInstance
910
calledges::Vector{Any}
1011
src::CodeInfo
1112
mod::Module
1213
nargs::Int
14+
world::UInt
1315
min_valid::UInt
1416
max_valid::UInt
15-
params::Params
1617
sptypes::Vector{Any} # static parameters
1718
slottypes::Vector{Any}
1819
const_api::Bool
1920
# cached results of calling `_methods_by_ftype` from inference, including
2021
# `min_valid` and `max_valid`
2122
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt}}
22-
function OptimizationState(frame::InferenceState)
23+
# TODO: This will be eliminated once optimization no longer needs to do method lookups
24+
interp::AbstractInterpreter
25+
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
2326
s_edges = frame.stmt_edges[1]
2427
if s_edges === nothing
2528
s_edges = []
2629
frame.stmt_edges[1] = s_edges
2730
end
2831
src = frame.src
29-
return new(frame.linfo,
32+
return new(params, frame.linfo,
3033
s_edges::Vector{Any},
3134
src, frame.mod, frame.nargs,
32-
frame.min_valid, frame.max_valid,
33-
frame.params, frame.sptypes, frame.slottypes, false,
34-
frame.matching_methods_cache)
35+
frame.world, frame.min_valid, frame.max_valid,
36+
frame.sptypes, frame.slottypes, false,
37+
frame.matching_methods_cache, interp)
3538
end
36-
function OptimizationState(linfo::MethodInstance, src::CodeInfo,
37-
params::Params)
39+
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
3840
# prepare src for running optimization passes
3941
# if it isn't already
4042
nssavalues = src.ssavaluetypes
@@ -57,19 +59,19 @@ mutable struct OptimizationState
5759
inmodule = linfo.def::Module
5860
nargs = 0
5961
end
60-
return new(linfo,
62+
return new(params, linfo,
6163
s_edges::Vector{Any},
6264
src, inmodule, nargs,
63-
UInt(1), get_world_counter(),
64-
params, sptypes_from_meth_instance(linfo), slottypes, false,
65-
IdDict{Any, Tuple{Any, UInt, UInt}}())
65+
get_world_counter(), UInt(1), get_world_counter(),
66+
sptypes_from_meth_instance(linfo), slottypes, false,
67+
IdDict{Any, Tuple{Any, UInt, UInt}}(), interp)
6668
end
6769
end
6870

69-
function OptimizationState(linfo::MethodInstance, params::Params)
71+
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
7072
src = retrieve_code_info(linfo)
7173
src === nothing && return nothing
72-
return OptimizationState(linfo, src, params)
74+
return OptimizationState(linfo, src, params, interp)
7375
end
7476

7577

@@ -109,7 +111,7 @@ _topmod(sv::OptimizationState) = _topmod(sv.mod)
109111
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState)
110112
sv.min_valid = max(sv.min_valid, min_valid)
111113
sv.max_valid = min(sv.max_valid, max_valid)
112-
@assert(sv.min_valid <= sv.params.world <= sv.max_valid,
114+
@assert(sv.min_valid <= sv.world <= sv.max_valid,
113115
"invalid age range update")
114116
nothing
115117
end
@@ -127,10 +129,10 @@ function add_backedge!(li::CodeInstance, caller::OptimizationState)
127129
nothing
128130
end
129131

130-
function isinlineable(m::Method, me::OptimizationState, bonus::Int=0)
132+
function isinlineable(m::Method, me::OptimizationState, params::OptimizationParams, bonus::Int=0)
131133
# compute the cost (size) of inlining this code
132134
inlineable = false
133-
cost_threshold = me.params.inline_cost_threshold
135+
cost_threshold = params.inline_cost_threshold
134136
if m.module === _topmod(m.module)
135137
# a few functions get special treatment
136138
name = m.name
@@ -145,7 +147,7 @@ function isinlineable(m::Method, me::OptimizationState, bonus::Int=0)
145147
end
146148
end
147149
if !inlineable
148-
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, me.params, cost_threshold + bonus)
150+
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, params, cost_threshold + bonus)
149151
end
150152
return inlineable
151153
end
@@ -168,7 +170,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
168170
end
169171

170172
# run the optimization work
171-
function optimize(opt::OptimizationState, @nospecialize(result))
173+
function optimize(opt::OptimizationState, params::OptimizationParams, @nospecialize(result))
172174
def = opt.linfo.def
173175
nargs = Int(opt.nargs) - 1
174176
@timeit "optimizer" ir = run_passes(opt.src, nargs, opt)
@@ -247,13 +249,13 @@ function optimize(opt::OptimizationState, @nospecialize(result))
247249
else
248250
bonus = 0
249251
if result ⊑ Tuple && !isbitstype(widenconst(result))
250-
bonus = opt.params.inline_tupleret_bonus
252+
bonus = params.inline_tupleret_bonus
251253
end
252254
if opt.src.inlineable
253255
# For functions declared @inline, increase the cost threshold 20x
254-
bonus += opt.params.inline_cost_threshold*19
256+
bonus += params.inline_cost_threshold*19
255257
end
256-
opt.src.inlineable = isinlineable(def, opt, bonus)
258+
opt.src.inlineable = isinlineable(def, opt, params, bonus)
257259
end
258260
end
259261
nothing
@@ -282,7 +284,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
282284
# known return type
283285
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))
284286

285-
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::Params)
287+
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::OptimizationParams)
286288
head = ex.head
287289
if is_meta_expr_head(head)
288290
return 0
@@ -372,7 +374,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
372374
end
373375

374376
function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
375-
params::Params, cost_threshold::Integer=params.inline_cost_threshold)
377+
params::OptimizationParams, cost_threshold::Integer=params.inline_cost_threshold)
376378
bodycost::Int = 0
377379
for line = 1:length(body)
378380
stmt = body[line]

base/compiler/params.jl

Lines changed: 0 additions & 72 deletions
This file was deleted.

0 commit comments

Comments
 (0)