Skip to content

Commit cefe938

Browse files
Kenoaviatesk
authored andcommitted
WIP
1 parent 9b2a506 commit cefe938

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

src/stage2/interpreter.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,11 @@ struct ADInterpreter <: AbstractInterpreter
5555
unopt::Union{OffsetVector{UnoptCache},Nothing}
5656
transformed::OffsetVector{OptCache}
5757

58+
# Cache results for forward inference over a converged inference (current_level == missing)
59+
generic::OptCache
60+
5861
native_interpreter::NativeInterpreter
59-
current_level::Int
62+
current_level::Union{Int, Missing}
6063
remarks::OffsetVector{RemarksCache}
6164

6265
function _ADInterpreter()
@@ -66,6 +69,7 @@ struct ADInterpreter <: AbstractInterpreter
6669
#=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
6770
#=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1),
6871
#=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
72+
OptCache(),
6973
#=native_interpreter::NativeInterpreter=#NativeInterpreter(),
7074
#=current_level::Int=#0,
7175
#=remarks::OffsetVector{RemarksCache}=#OffsetVector([RemarksCache()], 0:0))
@@ -76,10 +80,11 @@ struct ADInterpreter <: AbstractInterpreter
7680
opt::OffsetVector{OptCache} = interp.opt,
7781
unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt,
7882
transformed::OffsetVector{OptCache} = interp.transformed,
83+
generic::OptCache = interp.generic,
7984
native_interpreter::NativeInterpreter = interp.native_interpreter,
80-
current_level::Int = interp.current_level,
85+
current_level::Union{Int, Missing} = interp.current_level,
8186
remarks::OffsetVector{RemarksCache} = interp.remarks)
82-
return new(forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks)
87+
return new(forward, backward, opt, unopt, transformed, generic, native_interpreter, current_level, remarks)
8388
end
8489
end
8590

@@ -89,6 +94,27 @@ lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level -
8994

9095
disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false)
9196

97+
function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInterpreter)
98+
if interp.current_level === missing
99+
error()
100+
end
101+
return @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
102+
# prepare an InferenceState object for inferring lambda
103+
world = get_world_counter(interp)
104+
src = retrieve_code_info(result.linfo, world)
105+
src === nothing && return nothing
106+
validate_code_in_debug_mode(result.linfo, src, "lowered")
107+
return InferenceState(result, src, cache, interp, Bottom)
108+
end
109+
110+
111+
function CC.initial_bestguess(interp::ADInterpreter, result::InferenceResult)
112+
if interp.current_level === missing
113+
return CC.typeinf_lattice(interp.native_interpreter, result.linfo)
114+
end
115+
return Bottom
116+
end
117+
92118
function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor)
93119
@show curs
94120
(curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi]
@@ -335,15 +361,6 @@ function CC.inlining_policy(interp::ADInterpreter,
335361
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
336362
end
337363

338-
# TODO remove this overload once https://github.com/JuliaLang/julia/pull/49191 gets merged
339-
function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
340-
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
341-
sv::IRInterpretationState, max_methods::Int)
342-
return @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any,
343-
arginfo::ArgInfo, si::StmtInfo, atype::Any,
344-
sv::CC.AbsIntState, max_methods::Int)
345-
end
346-
347364
#=
348365
function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
349366
params::OptimizationParams, caller::InferenceResult)

0 commit comments

Comments
 (0)