Skip to content

Commit 47130c5

Browse files
vtjnashstaticfloat
authored andcommitted
inference: stop re-converging worlds after optimization (#38820)
The validity did not change, so we should not need to update it. This also ensures we copy over all result information earlier, so we can destroy the InferenceState slightly sooner, and slightly cleaner data flow. (cherry picked from commit 8c01444)
1 parent 951d1b3 commit 47130c5

File tree

4 files changed

+55
-45
lines changed

4 files changed

+55
-45
lines changed

base/compiler/compiler.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,10 @@ using .Sort
103103
# compiler #
104104
############
105105

106+
include("compiler/cicache.jl")
106107
include("compiler/types.jl")
107108
include("compiler/utilities.jl")
108109
include("compiler/validation.jl")
109-
110-
include("compiler/cicache.jl")
111110
include("compiler/methodtable.jl")
112111

113112
include("compiler/inferenceresult.jl")

base/compiler/optimize.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,15 @@ mutable struct OptimizationState
4444
const_api::Bool
4545
inlining::InliningState
4646
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
47-
s_edges = frame.stmt_edges[1]
48-
if s_edges === nothing
49-
s_edges = []
50-
frame.stmt_edges[1] = s_edges
51-
end
52-
src = frame.src
47+
s_edges = frame.stmt_edges[1]::Vector{Any}
5348
inlining = InliningState(params,
54-
EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
49+
EdgeTracker(s_edges, frame.valid_worlds),
5550
InferenceCaches(
5651
get_inference_cache(interp),
5752
WorldView(code_cache(interp), frame.world)),
5853
method_table(interp))
5954
return new(frame.linfo,
60-
src, frame.stmt_info, frame.mod, frame.nargs,
55+
frame.src, frame.stmt_info, frame.mod, frame.nargs,
6156
frame.sptypes, frame.slottypes, false,
6257
inlining)
6358
end

base/compiler/typeinfer.jl

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -217,21 +217,29 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
217217
# with no active ip's, frame is done
218218
frames = frame.callers_in_cycle
219219
isempty(frames) && push!(frames, frame)
220+
valid_worlds = WorldRange()
220221
for caller in frames
221222
@assert !(caller.dont_work_on_me)
222223
caller.dont_work_on_me = true
224+
# might might not fully intersect these earlier, so do that now
225+
valid_worlds = intersect(caller.valid_worlds, valid_worlds)
223226
end
224227
for caller in frames
228+
caller.valid_worlds = valid_worlds
225229
finish(caller, interp)
230+
# finalize and record the linfo result
231+
caller.inferred = true
226232
end
227233
# collect results for the new expanded frame
228-
results = Tuple{InferenceResult, Bool}[ ( frames[i].result,
229-
frames[i].cached || frames[i].parent !== nothing ) for i in 1:length(frames) ]
230-
# empty!(frames)
231-
valid_worlds = frame.valid_worlds
234+
results = Tuple{InferenceResult, Vector{Any}, Bool}[
235+
( frames[i].result,
236+
frames[i].stmt_edges[1],
237+
frames[i].cached || frames[i].parent !== nothing )
238+
for i in 1:length(frames) ]
239+
empty!(frames)
232240
cached = frame.cached
233241
if cached || frame.parent !== nothing
234-
for (caller, doopt) in results
242+
for (caller, _, doopt) in results
235243
opt = caller.src
236244
if opt isa OptimizationState
237245
run_optimizer = doopt && may_optimize(interp)
@@ -253,31 +261,24 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
253261
caller.src = nothing
254262
end
255263
end
256-
# As a hack the et reuses frame_edges[1] to push any optimization
257-
# edges into, so we don't need to handle them specially here
258-
valid_worlds = intersect(valid_worlds, opt.inlining.et.valid_worlds[])
264+
caller.valid_worlds = opt.inlining.et.valid_worlds[]
259265
end
260266
end
261267
end
262-
if last(valid_worlds) == get_world_counter()
263-
valid_worlds = WorldRange(first(valid_worlds), typemax(UInt))
264-
end
265-
for caller in frames
268+
for (caller, edges, doopt) in results
269+
valid_worlds = caller.valid_worlds
270+
if last(valid_worlds) == get_world_counter()
271+
valid_worlds = WorldRange(first(valid_worlds), typemax(UInt))
272+
end
266273
caller.valid_worlds = valid_worlds
267-
caller.src.min_world = first(valid_worlds)
268-
caller.src.max_world = last(valid_worlds)
269274
if cached
270-
cache_result!(interp, caller.result, valid_worlds)
275+
cache_result!(interp, caller)
271276
end
272-
if last(valid_worlds) == typemax(UInt)
277+
if doopt && last(valid_worlds) == typemax(UInt)
273278
# if we aren't cached, we don't need this edge
274279
# but our caller might, so let's just make it anyways
275-
for caller in frames
276-
store_backedges(caller)
277-
end
280+
store_backedges(caller, edges)
278281
end
279-
# finalize and record the linfo result
280-
caller.inferred = true
281282
end
282283
return true
283284
end
@@ -343,14 +344,16 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
343344
end
344345

345346
function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodInstance,
346-
@nospecialize(inferred_result))
347+
valid_worlds::WorldRange, @nospecialize(inferred_result))
347348
local const_flags::Int32
348349
# If we decided not to optimize, drop the OptimizationState now.
349350
# External interpreters can override as necessary to cache additional information
350351
if inferred_result isa OptimizationState
351352
inferred_result = inferred_result.src
352353
end
353354
if inferred_result isa CodeInfo
355+
inferred_result.min_world = first(valid_worlds)
356+
inferred_result.max_world = last(valid_worlds)
354357
inferred_result = maybe_compress_codeinfo(interp, linfo, inferred_result)
355358
end
356359
# The global cache can only handle objects that codegen understands
@@ -360,7 +363,8 @@ function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodIn
360363
return inferred_result
361364
end
362365

363-
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, valid_worlds::WorldRange)
366+
function cache_result!(interp::AbstractInterpreter, result::InferenceResult)
367+
valid_worlds = result.valid_worlds
364368
# check if the existing linfo metadata is also sufficient to describe the current inference result
365369
# to decide if it is worth caching this
366370
already_inferred = already_inferred_quick_test(interp, result.linfo)
@@ -370,7 +374,7 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult, val
370374

371375
# TODO: also don't store inferred code if we've previously decided to interpret this function
372376
if !already_inferred
373-
inferred_result = transform_result_for_cache(interp, result.linfo, result.src)
377+
inferred_result = transform_result_for_cache(interp, result.linfo, valid_worlds, result.src)
374378
code_cache(interp)[result.linfo] = CodeInstance(result, inferred_result, valid_worlds)
375379
end
376380
unlock_mi_inference(interp, result.linfo)
@@ -381,6 +385,21 @@ end
381385
# update the MethodInstance
382386
function finish(me::InferenceState, interp::AbstractInterpreter)
383387
# prepare to run optimization passes on fulltree
388+
s_edges = me.stmt_edges[1]
389+
if s_edges === nothing
390+
s_edges = []
391+
me.stmt_edges[1] = s_edges
392+
end
393+
for edges in me.stmt_edges
394+
edges === nothing && continue
395+
edges === s_edges && continue
396+
append!(s_edges, edges)
397+
empty!(edges)
398+
end
399+
if me.src.edges !== nothing
400+
append!(s_edges, me.src.edges)
401+
me.src.edges = nothing
402+
end
384403
if me.limited && me.cached && me.parent !== nothing
385404
# a top parent will be cached still, but not this intermediate work
386405
# we can throw everything else away now
@@ -392,6 +411,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
392411
type_annotate!(me)
393412
me.result.src = OptimizationState(me, OptimizationParams(interp), interp)
394413
end
414+
me.result.valid_worlds = me.valid_worlds
395415
me.result.result = me.bestguess
396416
nothing
397417
end
@@ -404,20 +424,15 @@ function finish(src::CodeInfo, interp::AbstractInterpreter)
404424
end
405425

406426
# record the backedges
407-
function store_backedges(frame::InferenceState)
427+
function store_backedges(frame::InferenceResult, edges::Vector{Any})
408428
toplevel = !isa(frame.linfo.def, Method)
409-
if !toplevel && (frame.cached || frame.parent !== nothing)
410-
caller = frame.result.linfo
411-
for edges in frame.stmt_edges
412-
store_backedges(caller, edges)
413-
end
414-
store_backedges(caller, frame.src.edges)
415-
frame.src.edges = nothing
429+
if !toplevel
430+
store_backedges(frame.linfo, edges)
416431
end
432+
nothing
417433
end
418434

419-
store_backedges(caller, edges::Nothing) = nothing
420-
function store_backedges(caller, edges::Vector)
435+
function store_backedges(caller::MethodInstance, edges::Vector)
421436
i = 1
422437
while i <= length(edges)
423438
to = edges[i]

base/compiler/types.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ mutable struct InferenceResult
2828
overridden_by_const::BitVector
2929
result # ::Type, or InferenceState if WIP
3030
src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
31+
valid_worlds::WorldRange # if inference and optimization is finished
3132
function InferenceResult(linfo::MethodInstance, given_argtypes = nothing)
3233
argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes)
33-
return new(linfo, argtypes, overridden_by_const, Any, nothing)
34+
return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange())
3435
end
3536
end
3637

0 commit comments

Comments
 (0)