You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add AbstractInterpreter to parameterize compilation pipeline
This allows selective overriding of the compilation pipeline through
multiple dispatch, enabling projects like `XLA.jl` to maintain separate
inference caches, inference algorithms or heuristic algorithms while
inferring and lowering code. In particular, it defines a new type,
`AbstractInterpreter`, that represents an abstract interpretation
pipeline. This `AbstractInterpreter` has a single defined concrete
subtype, `NativeInterpreter`, that represents the native Julia
compilation pipeline. The `NativeInterpreter` contains within it all
the compiler parameters previously contained within `Params`, split into
two pieces: `InferenceParams` and `OptimizationParams`, used within type
inference and optimization, respectively. The interpreter object is
then threaded throughout most of the type inference pipeline, and allows
for straightforward prototyping and replacement of the compiler
internals.
As a simple example of the kind of workflow this enables, I include here
a simple testing script showing how to use this to easily get a list
of the number of times a function is inferred during type inference by
overriding just two functions within the compiler. First, I will define
here some simple methods to make working with inference a bit easier:
```julia
using Core.Compiler
import Core.Compiler: InferenceParams, OptimizationParams, get_world_counter, get_inference_cache
"""
@infer_function interp foo(1, 2) [show_steps=true] [show_ir=false]
Infer a function call using the given interpreter object, return
the inference object. Set keyword arguments to modify verbosity:
* Set `show_steps` to `true` to see the `InferenceResult` step by step.
* Set `show_ir` to `true` to see the final type-inferred Julia IR.
"""
macro infer_function(interp, func_call, kwarg_exs...)
if !isa(func_call, Expr) || func_call.head != :call
error("@infer_function requires a function call")
end
local func = func_call.args[1]
local args = func_call.args[2:end]
kwargs = []
for ex in kwarg_exs
if ex isa Expr && ex.head === :(=) && ex.args[1] isa Symbol
push!(kwargs, first(ex.args) => last(ex.args))
else
error("Invalid @infer_function kwarg $(ex)")
end
end
return quote
infer_function($(esc(interp)), $(esc(func)), typeof.(($(args)...,)); $(esc(kwargs))...)
end
end
function infer_function(interp, f, tt; show_steps::Bool=false, show_ir::Bool=false)
# Find all methods that are applicable to these types
fms = methods(f, tt)
if length(fms) != 1
error("Unable to find single applicable method for $f with types $tt")
end
# Take the first applicable method
method = first(fms)
# Build argument tuple
method_args = Tuple{typeof(f), tt...}
# Grab the appropriate method instance for these types
mi = Core.Compiler.specialize_method(method, method_args, Core.svec())
# Construct InferenceResult to hold the result,
result = Core.Compiler.InferenceResult(mi)
if show_steps
@info("Initial result, before inference: ", result)
end
# Create an InferenceState to begin inference, give it a world that is always newest
world = Core.Compiler.get_world_counter()
frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)
# Run type inference on this frame. Because the interpreter is embedded
# within this InferenceResult, we don't need to pass the interpreter in.
Core.Compiler.typeinf_local(interp, frame)
if show_steps
@info("Ending result, post-inference: ", result)
end
if show_ir
@info("Inferred source: ", result.result.src)
end
# Give the result back
return result
end
```
Next, we define a simple function and pass it through:
```julia
function foo(x, y)
return x + y * x
end
native_interpreter = Core.Compiler.NativeInterpreter()
inferred = @infer_function native_interpreter foo(1.0, 2.0) show_steps=true show_ir=true
```
This gives a nice output such as the following:
```julia-repl
┌ Info: Initial result, before inference:
└ result = foo(::Float64, ::Float64) => Any
┌ Info: Ending result, post-inference:
└ result = foo(::Float64, ::Float64) => Float64
┌ Info: Inferred source:
│ result.result.src =
│ CodeInfo(
│ @ REPL[1]:3 within `foo'
│ 1 ─ %1 = (y * x)::Float64
│ │ %2 = (x + %1)::Float64
│ └── return %2
└ )
```
We can then define a custom `AbstractInterpreter` subtype that will
override two specific pieces of the compilation process; managing the
runtime inference cache. While it will transparently pass all information
through to a bundled `NativeInterpreter`, it has the ability to force cache
misses in order to re-infer things so that we can easily see how many
methods (and which) would be inferred to compile a certain method:
```julia
struct CountingInterpreter <: Compiler.AbstractInterpreter
visited_methods::Set{Core.Compiler.MethodInstance}
methods_inferred::Ref{UInt64}
# Keep around a native interpreter so that we can sub off to "super" functions
native_interpreter::Core.Compiler.NativeInterpreter
end
CountingInterpreter() = CountingInterpreter(
Set{Core.Compiler.MethodInstance}(),
Ref(UInt64(0)),
Core.Compiler.NativeInterpreter(),
)
InferenceParams(ci::CountingInterpreter) = InferenceParams(ci.native_interpreter)
OptimizationParams(ci::CountingInterpreter) = OptimizationParams(ci.native_interpreter)
get_world_counter(ci::CountingInterpreter) = get_world_counter(ci.native_interpreter)
get_inference_cache(ci::CountingInterpreter) = get_inference_cache(ci.native_interpreter)
function Core.Compiler.inf_for_methodinstance(interp::CountingInterpreter, mi::Core.Compiler.MethodInstance, min_world::UInt, max_world::UInt=min_world)
# Hit our own cache; if it exists, pass on to the main runtime
if mi in interp.visited_methods
return Core.Compiler.inf_for_methodinstance(interp.native_interpreter, mi, min_world, max_world)
end
# Otherwise, we return `nothing`, forcing a cache miss
return nothing
end
function Core.Compiler.cache_result(interp::CountingInterpreter, result::Core.Compiler.InferenceResult, min_valid::UInt, max_valid::UInt)
push!(interp.visited_methods, result.linfo)
interp.methods_inferred[] += 1
return Core.Compiler.cache_result(interp.native_interpreter, result, min_valid, max_valid)
end
function reset!(interp::CountingInterpreter)
empty!(interp.visited_methods)
interp.methods_inferred[] = 0
return nothing
end
```
Running it on our testing function:
```julia
counting_interpreter = CountingInterpreter()
inferred = @infer_function counting_interpreter foo(1.0, 2.0)
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
inferred = @infer_function counting_interpreter foo(1, 2) show_ir=true
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
inferred = @infer_function counting_interpreter foo(1.0, 2.0)
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
reset!(counting_interpreter)
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
inferred = @infer_function counting_interpreter foo(1.0, 2.0)
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
```
Also gives us a nice result:
```
[ Info: Cumulative number of methods inferred: 2
┌ Info: Inferred source:
│ result.result.src =
│ CodeInfo(
│ @ /Users/sabae/src/julia-compilerhack/AbstractInterpreterTest.jl:81 within `foo'
│ 1 ─ %1 = (y * x)::Int64
│ │ %2 = (x + %1)::Int64
│ └── return %2
└ )
[ Info: Cumulative number of methods inferred: 4
[ Info: Cumulative number of methods inferred: 4
[ Info: Cumulative number of methods inferred: 0
[ Info: Cumulative number of methods inferred: 2
```
0 commit comments