From cefe9386242f016463cebe53f4f6e70c75d3d52f Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Mon, 31 Jul 2023 20:16:45 +0000 Subject: [PATCH 1/4] WIP --- src/stage2/interpreter.jl | 41 +++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 17b51eba..303d93df 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -55,8 +55,11 @@ struct ADInterpreter <: AbstractInterpreter unopt::Union{OffsetVector{UnoptCache},Nothing} transformed::OffsetVector{OptCache} + # Cache results for forward inference over a converged inference (current_level == missing) + generic::OptCache + native_interpreter::NativeInterpreter - current_level::Int + current_level::Union{Int, Missing} remarks::OffsetVector{RemarksCache} function _ADInterpreter() @@ -66,6 +69,7 @@ struct ADInterpreter <: AbstractInterpreter #=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1), #=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1), #=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1), + OptCache(), #=native_interpreter::NativeInterpreter=#NativeInterpreter(), #=current_level::Int=#0, #=remarks::OffsetVector{RemarksCache}=#OffsetVector([RemarksCache()], 0:0)) @@ -76,10 +80,11 @@ struct ADInterpreter <: AbstractInterpreter opt::OffsetVector{OptCache} = interp.opt, unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt, transformed::OffsetVector{OptCache} = interp.transformed, + generic::OptCache = interp.generic, native_interpreter::NativeInterpreter = interp.native_interpreter, - current_level::Int = interp.current_level, + current_level::Union{Int, Missing} = interp.current_level, remarks::OffsetVector{RemarksCache} = interp.remarks) - return new(forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks) + return new(forward, backward, opt, unopt, transformed, generic, native_interpreter, current_level, remarks) end end @@ -89,6 +94,27 @@ lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false) +function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInterpreter) + if interp.current_level === missing + error() + end + return @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter) + # prepare an InferenceState object for inferring lambda + world = get_world_counter(interp) + src = retrieve_code_info(result.linfo, world) + src === nothing && return nothing + validate_code_in_debug_mode(result.linfo, src, "lowered") + return InferenceState(result, src, cache, interp, Bottom) +end + + +function CC.initial_bestguess(interp::ADInterpreter, result::InferenceResult) + if interp.current_level === missing + return CC.typeinf_lattice(interp.native_interpreter, result.linfo) + end + return Bottom +end + function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) @show curs (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi] @@ -335,15 +361,6 @@ function CC.inlining_policy(interp::ADInterpreter, nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) end -# TODO remove this overload once https://github.com/JuliaLang/julia/pull/49191 gets merged -function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), - sv::IRInterpretationState, max_methods::Int) - return @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any, - arginfo::ArgInfo, si::StmtInfo, atype::Any, - sv::CC.AbsIntState, max_methods::Int) -end - #= function CC.optimize(interp::ADInterpreter, opt::OptimizationState, params::OptimizationParams, caller::InferenceResult) From 607ba1875553ca503ead223a290f79386ba72a8f Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 31 Jul 2023 16:50:20 -0400 Subject: [PATCH 2/4] wip --- src/stage2/interpreter.jl | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 303d93df..853c4a68 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -95,24 +95,24 @@ lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false) function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInterpreter) + sv = @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter) + sv === nothing && return sv if interp.current_level === missing - error() + # override initial bestguess + arginfo = ArgInfo(nothing, result.argtypes) + si = StmtInfo(true) + sv.bestguess = CC.abstract_call(interp.native_interpreter, arginfo, si, sv).rt end - return @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter) - # prepare an InferenceState object for inferring lambda - world = get_world_counter(interp) - src = retrieve_code_info(result.linfo, world) - src === nothing && return nothing - validate_code_in_debug_mode(result.linfo, src, "lowered") - return InferenceState(result, src, cache, interp, Bottom) + return sv end - -function CC.initial_bestguess(interp::ADInterpreter, result::InferenceResult) +function CC.update_bestguess!(interp::ADInterpreter, frame::InferenceState, + currstate::CC.VarTable, @nospecialize(rt)) if interp.current_level === missing - return CC.typeinf_lattice(interp.native_interpreter, result.linfo) + rt = CC.getfield_tfunc(rt, Const(1)) end - return Bottom + return @invoke CC.update_bestguess!(interp::AbstractInterpreter, frame::InferenceState, + currstate::CC.VarTable, rt::Any) end function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) @@ -121,7 +121,6 @@ function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) end Cthulhu.AbstractCursor(interp::ADInterpreter, mi::MethodInstance) = ADCursor(0, mi, false) - # This is a lie, but let's clean this up later Cthulhu.can_descend(interp::ADInterpreter, @nospecialize(key), optimize::Bool) = true From 0362267013baeacc1f845ab9b2e27093a684f17a Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Tue, 1 Aug 2023 06:23:49 +0900 Subject: [PATCH 3/4] Update src/stage2/interpreter.jl --- src/stage2/interpreter.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 853c4a68..18023962 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -100,7 +100,7 @@ function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInt if interp.current_level === missing # override initial bestguess arginfo = ArgInfo(nothing, result.argtypes) - si = StmtInfo(true) + si = StmtInfo(false) sv.bestguess = CC.abstract_call(interp.native_interpreter, arginfo, si, sv).rt end return sv From 2438a4777d44a0557c030eb3d3f606f33d0a315d Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 31 Jul 2023 18:01:13 -0400 Subject: [PATCH 4/4] fix on 1.10 --- src/stage2/interpreter.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 18023962..93de83d7 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -94,6 +94,10 @@ lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false) +const GENERIC_INFERENCE_ENABLED = isdefined(CC, :update_bestguess!) + +@static if GENERIC_INFERENCE_ENABLED + function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInterpreter) sv = @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter) sv === nothing && return sv @@ -115,6 +119,8 @@ function CC.update_bestguess!(interp::ADInterpreter, frame::InferenceState, currstate::CC.VarTable, rt::Any) end +end # @static if GENERIC_INFERENCE_ENABLED + function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) @show curs (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi]