Skip to content

AbstractInterpreter: add a hook to customize bestguess calculation #50744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 37 additions & 30 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2887,17 +2887,49 @@ function init_vartable!(vartable::VarTable, frame::InferenceState)
return vartable
end

function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
currstate::VarTable, @nospecialize(rt))
bestguess = frame.bestguess
nargs = narguments(frame, #=include_va=#false)
slottypes = frame.slottypes
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
slot_id = rt.slot
old_id_type = slottypes[slot_id]
if bestguess.val === true && rt.elsetype !== Bottom
bestguess = InterConditional(slot_id, old_id_type, Bottom)
elseif bestguess.val === false && rt.thentype !== Bottom
bestguess = InterConditional(slot_id, Bottom, old_id_type)
end
end
# copy limitations to return value
if !isempty(frame.pclimitations)
union!(frame.limitations, frame.pclimitations)
empty!(frame.pclimitations)
end
if !isempty(frame.limitations)
rt = LimitedAccuracy(rt, copy(frame.limitations))
end
𝕃ₚ = ipo_lattice(interp)
if !⊑(𝕃ₚ, rt, bestguess)
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
frame.bestguess = tmerge(𝕃ₚ, bestguess, rt) # new (wider) return type for frame
return true
else
return false
end
end

# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@assert !is_inferred(frame)
frame.dont_work_on_me = true # mark that this function is currently on the stack
W = frame.ip
nargs = narguments(frame, #=include_va=#false)
slottypes = frame.slottypes
ssavaluetypes = frame.ssavaluetypes
bbs = frame.cfg.blocks
nbbs = length(bbs)
𝕃ₚ, 𝕃ᵢ = ipo_lattice(interp), typeinf_lattice(interp)
𝕃ᵢ = typeinf_lattice(interp)

currbb = frame.currbb
if currbb != 1
Expand Down Expand Up @@ -2998,35 +3030,10 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end
elseif isa(stmt, ReturnNode)
bestguess = frame.bestguess
rt = abstract_eval_value(interp, stmt.val, currstate, frame)
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
let slot_id = rt.slot
old_id_type = slottypes[slot_id]
if bestguess.val === true && rt.elsetype !== Bottom
bestguess = InterConditional(slot_id, old_id_type, Bottom)
elseif bestguess.val === false && rt.thentype !== Bottom
bestguess = InterConditional(slot_id, Bottom, old_id_type)
end
end
end
# copy limitations to return value
if !isempty(frame.pclimitations)
union!(frame.limitations, frame.pclimitations)
empty!(frame.pclimitations)
end
if !isempty(frame.limitations)
rt = LimitedAccuracy(rt, copy(frame.limitations))
end
if !⊑(𝕃ₚ, rt, bestguess)
# new (wider) return type for frame
bestguess = tmerge(𝕃ₚ, bestguess, rt)
# TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end
frame.bestguess = bestguess
if update_bestguess!(interp, frame, currstate, rt)
for (caller, caller_pc) in frame.cycle_backedges
if !(caller.ssavaluetypes[caller_pc] === Any)
if caller.ssavaluetypes[caller_pc] !== Any
# no reason to revisit if that call-site doesn't affect the final result
push!(caller.ip, block_for_inst(caller.cfg, caller_pc))
end
Expand Down
39 changes: 21 additions & 18 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -870,26 +870,10 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# since the inliner will request to use it later
cache = :local
else
rt = cached_return_type(code)
effects = ipo_effects(code)
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
rettype = code.rettype
if isdefined(code, :rettype_const)
rettype_const = code.rettype_const
# the second subtyping/egal conditions are necessary to distinguish usual cases
# from rare cases when `Const` wrapped those extended lattice type objects
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
rettype = PartialStruct(rettype, rettype_const)
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
rettype = rettype_const
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
rettype = rettype_const
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
rettype = rettype_const
else
rettype = Const(rettype_const)
end
end
return EdgeCallResult(rettype, mi, effects)
return EdgeCallResult(rt, mi, effects)
end
else
cache = :global # cache edge targets by default
Expand Down Expand Up @@ -933,6 +917,25 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
return EdgeCallResult(frame.bestguess, nothing, adjust_effects(frame))
end

function cached_return_type(code::CodeInstance)
rettype = code.rettype
isdefined(code, :rettype_const) || return rettype
rettype_const = code.rettype_const
# the second subtyping/egal conditions are necessary to distinguish usual cases
# from rare cases when `Const` wrapped those extended lattice type objects
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
return PartialStruct(rettype, rettype_const)
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
return rettype_const
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
return rettype_const
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
return rettype_const
else
return Const(rettype_const)
end
end

#### entry points for inferring a MethodInstance given a type signature ####

# compute an inferred AST and return type
Expand Down