diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 17b51eba..93de83d7 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -55,8 +55,11 @@ struct ADInterpreter <: AbstractInterpreter unopt::Union{OffsetVector{UnoptCache},Nothing} transformed::OffsetVector{OptCache} + # Cache results for forward inference over a converged inference (current_level == missing) + generic::OptCache + native_interpreter::NativeInterpreter - current_level::Int + current_level::Union{Int, Missing} remarks::OffsetVector{RemarksCache} function _ADInterpreter() @@ -66,6 +69,7 @@ struct ADInterpreter <: AbstractInterpreter #=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1), #=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1), #=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1), + OptCache(), #=native_interpreter::NativeInterpreter=#NativeInterpreter(), #=current_level::Int=#0, #=remarks::OffsetVector{RemarksCache}=#OffsetVector([RemarksCache()], 0:0)) @@ -76,10 +80,11 @@ struct ADInterpreter <: AbstractInterpreter opt::OffsetVector{OptCache} = interp.opt, unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt, transformed::OffsetVector{OptCache} = interp.transformed, + generic::OptCache = interp.generic, native_interpreter::NativeInterpreter = interp.native_interpreter, - current_level::Int = interp.current_level, + current_level::Union{Int, Missing} = interp.current_level, remarks::OffsetVector{RemarksCache} = interp.remarks) - return new(forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks) + return new(forward, backward, opt, unopt, transformed, generic, native_interpreter, current_level, remarks) end end @@ -89,13 +94,39 @@ lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false) +const GENERIC_INFERENCE_ENABLED = isdefined(CC, :update_bestguess!) + +@static if GENERIC_INFERENCE_ENABLED + +function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInterpreter) + sv = @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter) + sv === nothing && return sv + if interp.current_level === missing + # override initial bestguess + arginfo = ArgInfo(nothing, result.argtypes) + si = StmtInfo(false) + sv.bestguess = CC.abstract_call(interp.native_interpreter, arginfo, si, sv).rt + end + return sv +end + +function CC.update_bestguess!(interp::ADInterpreter, frame::InferenceState, + currstate::CC.VarTable, @nospecialize(rt)) + if interp.current_level === missing + rt = CC.getfield_tfunc(rt, Const(1)) + end + return @invoke CC.update_bestguess!(interp::AbstractInterpreter, frame::InferenceState, + currstate::CC.VarTable, rt::Any) +end + +end # @static if GENERIC_INFERENCE_ENABLED + function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) @show curs (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi] end Cthulhu.AbstractCursor(interp::ADInterpreter, mi::MethodInstance) = ADCursor(0, mi, false) - # This is a lie, but let's clean this up later Cthulhu.can_descend(interp::ADInterpreter, @nospecialize(key), optimize::Bool) = true @@ -335,15 +366,6 @@ function CC.inlining_policy(interp::ADInterpreter, nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) end -# TODO remove this overload once https://github.com/JuliaLang/julia/pull/49191 gets merged -function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), - sv::IRInterpretationState, max_methods::Int) - return @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, - arginfo::ArgInfo, si::StmtInfo, atype::Any, - sv::CC.AbsIntState, max_methods::Int) -end - #= function CC.optimize(interp::ADInterpreter, opt::OptimizationState, params::OptimizationParams, caller::InferenceResult)