@@ -55,8 +55,11 @@ struct ADInterpreter <: AbstractInterpreter
     unopt::Union{OffsetVector{UnoptCache},Nothing}
     transformed::OffsetVector{OptCache}
 
+    # Cache results for forward inference over a converged inference (current_level == missing)
+    generic::OptCache
+
     native_interpreter::NativeInterpreter
-    current_level::Int
+    current_level::Union{Int, Missing}
     remarks::OffsetVector{RemarksCache}
 
     function _ADInterpreter()
@@ -66,6 +69,7 @@ struct ADInterpreter <: AbstractInterpreter
             #=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
             #=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1),
             #=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
+            OptCache(),
             #=native_interpreter::NativeInterpreter=#NativeInterpreter(),
             #=current_level::Int=#0,
             #=remarks::OffsetVector{RemarksCache}=#OffsetVector([RemarksCache()], 0:0))
@@ -76,10 +80,11 @@ struct ADInterpreter <: AbstractInterpreter
             opt::OffsetVector{OptCache} = interp.opt,
             unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt,
             transformed::OffsetVector{OptCache} = interp.transformed,
+            generic::OptCache = interp.generic,
             native_interpreter::NativeInterpreter = interp.native_interpreter,
-            current_level::Int = interp.current_level,
+            current_level::Union{Int, Missing} = interp.current_level,
             remarks::OffsetVector{RemarksCache} = interp.remarks)
-        return new(forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks)
+        return new(forward, backward, opt, unopt, transformed, generic, native_interpreter, current_level, remarks)
     end
 end
 
@@ -89,6 +94,27 @@ lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level -
 
 disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false)
 
+function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInterpreter)
+    if interp.current_level === missing
+        error()
+    end
+    return @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
+    # prepare an InferenceState object for inferring lambda
+    world = get_world_counter(interp)
+    src = retrieve_code_info(result.linfo, world)
+    src === nothing && return nothing
+    validate_code_in_debug_mode(result.linfo, src, "lowered")
+    return InferenceState(result, src, cache, interp, Bottom)
+end
+
+
+function CC.initial_bestguess(interp::ADInterpreter, result::InferenceResult)
+    if interp.current_level === missing
+        return CC.typeinf_lattice(interp.native_interpreter, result.linfo)
+    end
+    return Bottom
+end
+
 function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor)
     @show curs
     (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi]
@@ -335,15 +361,6 @@ function CC.inlining_policy(interp::ADInterpreter,
         nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
 end
 
-# TODO remove this overload once https://github.com/JuliaLang/julia/pull/49191 gets merged
-function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
-    arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
-    sv::IRInterpretationState, max_methods::Int)
-    return @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any,
-        arginfo::ArgInfo, si::StmtInfo, atype::Any,
-        sv::CC.AbsIntState, max_methods::Int)
-end
-
 #=
 function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
     params::OptimizationParams, caller::InferenceResult)
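
For reference, a brief usage sketch (not part of the commit) of how the new missing level introduced above might be selected. It mirrors the existing disable_forward/change_level helpers that appear in the diff context; the helper name generic_level is hypothetical.

# Hypothetical helper (assumption, not in this diff): derive an interpreter that
# performs forward inference over the converged result, i.e. the state described by
# the new `generic::OptCache` field and `current_level == missing` above.
generic_level(interp::ADInterpreter) = ADInterpreter(interp; current_level=missing)

# At this level, the `CC.InferenceState` overload added above raises an error, and
# `CC.initial_bestguess` falls back to the native interpreter's lattice instead of Bottom.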