Skip to content

Commit 5ffbc97

Browse files
committed
inference: track reaching defs for slots
This change effectively computes the SSA / ϕ-nodes for program slots as part of type-inference, using the "path-convergence criterion" for SSA. This allows us to conveniently reason about slot identity (in typical SSA fashion) without having to quadratically expand all of our SSA type state over the CFG.
1 parent 78b0b74 commit 5ffbc97

File tree

13 files changed

+226
-150
lines changed

13 files changed

+226
-150
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 130 additions & 89 deletions
Large diffs are not rendered by default.

base/compiler/inferencestate.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,13 @@ mutable struct InferenceState
315315
nargtypes = length(argtypes)
316316
for i = 1:nslots
317317
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
318+
# 0 = function entry (think carefully)
318319
if argtyp === Bool && has_conditional(typeinf_lattice(interp))
319-
argtyp = Conditional(i, Const(true), Const(false))
320+
argtyp = Conditional(i, #= ssadef =# 0, Const(true), Const(false))
320321
end
321322
slottypes[i] = argtyp
322-
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
323+
# 0 = function entry (think carefully)
324+
bb_vartable1[i] = VarState(argtyp, #= ssadef =# 0, i > nargtypes)
323325
end
324326
src.ssavaluetypes = ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
325327

@@ -712,7 +714,7 @@ function sptypes_from_meth_instance(mi::MethodInstance)
712714
ty = Const(v)
713715
undef = false
714716
end
715-
sptypes[i] = VarState(ty, undef)
717+
sptypes[i] = VarState(ty, typemin(Int), undef)
716718
end
717719
return sptypes
718720
end

base/compiler/optimize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
197197
bb_vartables = Union{VarTable,Nothing}[]
198198
for block = 1:length(cfg.blocks)
199199
push!(bb_vartables, VarState[
200-
VarState(slottypes[slot], src.slotflags[slot] & SLOT_USEDUNDEF != 0)
200+
VarState(slottypes[slot], typemin(Int), src.slotflags[slot] & SLOT_USEDUNDEF != 0)
201201
for slot = 1:nslots
202202
])
203203
end

base/compiler/ssair/irinterp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ function abstract_eval_phi_stmt(interp::AbstractInterpreter, phi::PhiNode, ::Int
4949
return abstract_eval_phi(interp, phi, nothing, irsv)
5050
end
5151

52-
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
52+
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, vtypes::Union{VarTable,Nothing}, irsv::IRInterpretationState)
5353
si = StmtInfo(true) # TODO better job here?
54-
call = abstract_call(interp, arginfo, si, irsv)
54+
call = abstract_call(interp, arginfo, si, vtypes, irsv)
5555
irsv.ir.stmts[irsv.curridx][:info] = call.info
5656
return call
5757
end

base/compiler/tfuncs.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ end
229229

230230
function not_tfunc(𝕃::AbstractLattice, @nospecialize(b))
231231
if isa(b, Conditional)
232-
return Conditional(b.slot, b.elsetype, b.thentype)
232+
return Conditional(b.slot, b.ssadef, b.elsetype, b.thentype)
233233
elseif isa(b, Const)
234234
return Const(not_int(b.val))
235235
end
@@ -350,14 +350,14 @@ end
350350
if isa(x, Conditional)
351351
y = widenconditional(y)
352352
if isa(y, Const)
353-
y.val === false && return Conditional(x.slot, x.elsetype, x.thentype)
353+
y.val === false && return Conditional(x.slot, x.ssadef, x.elsetype, x.thentype)
354354
y.val === true && return x
355355
return Const(false)
356356
end
357357
elseif isa(y, Conditional)
358358
x = widenconditional(x)
359359
if isa(x, Const)
360-
x.val === false && return Conditional(y.slot, y.elsetype, y.thentype)
360+
x.val === false && return Conditional(y.slot, y.ssadef, y.elsetype, y.thentype)
361361
x.val === true && return y
362362
return Const(false)
363363
end
@@ -1415,7 +1415,7 @@ end
14151415
# as well as compute the info for the method matches
14161416
op = unwrapva(argtypes[op_argi])
14171417
v = unwrapva(argtypes[v_argi])
1418-
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true), sv, #=max_methods=#1)
1418+
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true), vtypes, sv, #=max_methods=#1)
14191419
TF2 = tmeet(callinfo.rt, widenconst(TF))
14201420
if TF2 === Bottom
14211421
RT = Bottom
@@ -2931,10 +2931,11 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
29312931
if isa(sv, InferenceState)
29322932
old_restrict = sv.restrict_abstract_call_sites
29332933
sv.restrict_abstract_call_sites = false
2934-
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, #=max_methods=#-1)
2934+
# TODO: vtypes?
2935+
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, nothing, sv, #=max_methods=#-1)
29352936
sv.restrict_abstract_call_sites = old_restrict
29362937
else
2937-
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, #=max_methods=#-1)
2938+
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, nothing, sv, #=max_methods=#-1)
29382939
end
29392940
info = verbose_stmt_info(interp) ? MethodResultPure(ReturnTypeCallInfo(call.info)) : MethodResultPure()
29402941
rt = widenslotwrapper(call.rt)

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState)
744744
for slot in 1:nslots
745745
vt = varstate[slot]
746746
widened_type = widenslotwrapper(ignorelimited(vt.typ))
747-
varstate[slot] = VarState(widened_type, vt.undef)
747+
varstate[slot] = VarState(widened_type, vt.ssadef, vt.undef)
748748
end
749749
end
750750
end

base/compiler/typelattice.jl

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ the type of `SlotNumber(cnd.slot)` will be limited by `cnd.thentype`
6060
and in the false branch, it will be limited by `cnd.elsetype`.
6161
Example:
6262
```julia
63-
let cond = isa(x::Union{Int, Float}, Int)::Conditional(x, Int, Float)
63+
let cond = isa(x::Union{Int, Float}, Int)::Conditional(x, _, Int, Float)
6464
if cond
6565
# May assume x is `Int` now
6666
else
@@ -71,21 +71,22 @@ end
7171
"""
7272
struct Conditional
7373
slot::Int
74+
ssadef::Int
7475
thentype
7576
elsetype
7677
# `isdefined` indicates this `Conditional` is from `@isdefined slot`, implying that
7778
# the `undef` information of `slot` can be improved in the then branch.
7879
# Since this is only beneficial for local inference, it is not translated into `InterConditional`.
7980
isdefined::Bool
80-
function Conditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype);
81+
function Conditional(slot::Int, ssadef::Int, @nospecialize(thentype), @nospecialize(elsetype);
8182
isdefined::Bool=false)
8283
assert_nested_slotwrapper(thentype)
8384
assert_nested_slotwrapper(elsetype)
84-
return new(slot, thentype, elsetype, isdefined)
85+
return new(slot, ssadef, thentype, elsetype, isdefined)
8586
end
8687
end
87-
Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype); isdefined::Bool=false) =
88-
Conditional(slot_id(var), thentype, elsetype; isdefined)
88+
Conditional(var::SlotNumber, ssadef::Int, @nospecialize(thentype), @nospecialize(elsetype); isdefined::Bool=false) =
89+
Conditional(slot_id(var), ssadef, thentype, elsetype; isdefined)
8990

9091
import Core: InterConditional
9192
"""
@@ -105,8 +106,10 @@ InterConditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetyp
105106
InterConditional(slot_id(var), thentype, elsetype)
106107

107108
const AnyConditional = Union{Conditional,InterConditional}
108-
Conditional(cnd::InterConditional) = Conditional(cnd.slot, cnd.thentype, cnd.elsetype)
109-
InterConditional(cnd::Conditional) = InterConditional(cnd.slot, cnd.thentype, cnd.elsetype)
109+
function InterConditional(cnd::Conditional)
110+
@assert cnd.ssadef == 0
111+
InterConditional(cnd.slot, cnd.thentype, cnd.elsetype)
112+
end
110113

111114
"""
112115
alias::MustAlias
@@ -184,8 +187,20 @@ end
184187
struct StateUpdate
185188
var::SlotNumber
186189
vtype::VarState
187-
conditional::Bool
188-
StateUpdate(var::SlotNumber, vtype::VarState, conditional::Bool=false) = new(var, vtype, conditional)
190+
end
191+
192+
"""
193+
Similar to `StateUpdate`, except with the additional guarantee that object identity
194+
is preserved by the update (i.e. `x (before) === x (after)`).
195+
"""
196+
struct StateRefinement
197+
slot::Int
198+
# XXX: This should be an intersection of the old type with the new
199+
# (i.e. newtyp ⊑ oldtyp)
200+
newtyp
201+
undef::Bool
202+
203+
StateRefinement(slot::Int, @nospecialize(newtyp), undef::Bool) = new(slot, newtyp, undef)
189204
end
190205

191206
"""
@@ -328,6 +343,7 @@ end
328343
return false
329344
end
330345

346+
is_same_conditionals(a::Conditional, b::Conditional) = a.slot == b.slot && a.ssadef == b.ssadef
331347
is_same_conditionals(a::C, b::C) where C<:AnyConditional = a.slot == b.slot
332348

333349
@nospecializeinfer is_lattice_bool(lattice::AbstractLattice, @nospecialize(typ)) = typ !== Bottom && (lattice, typ, Bool)
@@ -387,7 +403,7 @@ end
387403
elsefields === nothing || (elsefields[i] = elsetype)
388404
end
389405
end
390-
return Conditional(slot,
406+
return Conditional(slot, typemin(Int), # TODO
391407
thenfields === nothing ? Bottom : PartialStruct(vartyp.typ, thenfields),
392408
elsefields === nothing ? Bottom : PartialStruct(vartyp.typ, elsefields))
393409
else
@@ -404,7 +420,7 @@ end
404420
elsefields === nothing || push!(elsefields, t)
405421
end
406422
end
407-
return Conditional(slot,
423+
return Conditional(slot, typemin(Int),
408424
thenfields === nothing ? Bottom : PartialStruct(vartyp_widened, thenfields),
409425
elsefields === nothing ? Bottom : PartialStruct(vartyp_widened, elsefields))
410426
end
@@ -745,34 +761,39 @@ widenconst(::LimitedAccuracy) = error("unhandled LimitedAccuracy")
745761
# state management #
746762
####################
747763

748-
function smerge(lattice::AbstractLattice, sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState})
764+
function smerge(lattice::AbstractLattice, sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState}, join_pc::Int)
749765
sa === sb && return sa
750766
sa === NOT_FOUND && return sb
751767
sb === NOT_FOUND && return sa
752-
return VarState(tmerge(lattice, sa.typ, sb.typ), sa.undef | sb.undef)
768+
return VarState(tmerge(lattice, sa.typ, sb.typ), sa.ssadef == sb.ssadef ? sa.ssadef : join_pc, sa.undef | sb.undef)
753769
end
754770

755-
@nospecializeinfer @inline schanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o)) =
756-
(n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !(n.undef <= o.undef && (lattice, n.typ, o.typ))))
771+
@nospecializeinfer @inline schanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o), join_pc::Int) =
772+
(n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !(n.undef <= o.undef && (n.ssadef == o.ssadef || o.ssadef == join_pc) && (lattice, n.typ, o.typ))))
757773

758774
# remove any lattice elements that wrap the reassigned slot object from the vartable
759-
function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional::Bool)
775+
function invalidate_slotwrapper(vt::VarState, changeid::Int)
760776
newtyp = ignorelimited(vt.typ)
761-
if (!ignore_conditional && isa(newtyp, Conditional) && newtyp.slot == changeid) ||
762-
(isa(newtyp, MustAlias) && newtyp.slot == changeid)
777+
if ((isa(newtyp, Conditional) && newtyp.slot == changeid) ||
778+
(isa(newtyp, MustAlias) && newtyp.slot == changeid))
763779
newtyp = @noinline widenwrappedslotwrapper(vt.typ)
764-
return VarState(newtyp, vt.undef)
780+
return VarState(newtyp, vt.ssadef, vt.undef)
765781
end
766782
return nothing
767783
end
768784

769-
function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
785+
function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable, join_pc::Int)
770786
changed = false
771787
for i = 1:length(state)
772788
newtype = changes[i]
773789
oldtype = state[i]
774-
if schanged(lattice, newtype, oldtype)
775-
state[i] = smerge(lattice, oldtype, newtype)
790+
# In addition to computing the type, the merge here computes the "reaching definition"
791+
# for a slot. The provided `join_pc` is a "virtual" PC, which corresponds to the ϕ-block
792+
# that would exist at the beginning of the BasicBlock.
793+
#
794+
# This effectively applies the "path-convergence criterion" for SSA construction.
795+
if schanged(lattice, newtype, oldtype, join_pc)
796+
state[i] = smerge(lattice, oldtype, newtype, join_pc)
776797
changed = true
777798
end
778799
end
@@ -789,7 +810,7 @@ end
789810
function stoverwrite1!(state::VarTable, change::StateUpdate)
790811
changeid = slot_id(change.var)
791812
for i = 1:length(state)
792-
invalidated = invalidate_slotwrapper(state[i], changeid, change.conditional)
813+
invalidated = invalidate_slotwrapper(state[i], changeid)
793814
if invalidated !== nothing
794815
state[i] = invalidated
795816
end
@@ -799,3 +820,9 @@ function stoverwrite1!(state::VarTable, change::StateUpdate)
799820
state[changeid] = newtype
800821
return state
801822
end
823+
824+
function strefine1!(state::VarTable, refinement::StateRefinement)
825+
(; newtyp, undef, slot) = refinement
826+
state[slot] = VarState(newtyp, state[slot].ssadef, undef)
827+
return state
828+
end

base/compiler/typelimits.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -494,24 +494,24 @@ end
494494
# type-lattice for Conditional wrapper (NOTE never be merged with InterConditional)
495495
if isa(typea, Conditional) && isa(typeb, Const)
496496
if typeb.val === true
497-
typeb = Conditional(typea.slot, Any, Union{})
497+
typeb = Conditional(typea.slot, typea.ssadef, Any, Union{})
498498
elseif typeb.val === false
499-
typeb = Conditional(typea.slot, Union{}, Any)
499+
typeb = Conditional(typea.slot, typea.ssadef, Union{}, Any)
500500
end
501501
end
502502
if isa(typeb, Conditional) && isa(typea, Const)
503503
if typea.val === true
504-
typea = Conditional(typeb.slot, Any, Union{})
504+
typea = Conditional(typeb.slot, typeb.ssadef, Any, Union{})
505505
elseif typea.val === false
506-
typea = Conditional(typeb.slot, Union{}, Any)
506+
typea = Conditional(typeb.slot, typeb.ssadef, Union{}, Any)
507507
end
508508
end
509509
if isa(typea, Conditional) && isa(typeb, Conditional)
510510
if is_same_conditionals(typea, typeb)
511511
thentype = tmerge(widenlattice(lattice), typea.thentype, typeb.thentype)
512512
elsetype = tmerge(widenlattice(lattice), typea.elsetype, typeb.elsetype)
513513
if thentype !== elsetype
514-
return Conditional(typea.slot, thentype, elsetype)
514+
return Conditional(typea.slot, typea.ssadef, thentype, elsetype)
515515
end
516516
end
517517
val = maybe_extract_const_bool(typea)

base/compiler/types.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,20 @@ MethodInfo(src::CodeInfo) = MethodInfo(
4747
A special wrapper that represents a local variable of a method being analyzed.
4848
This does not participate in the native type system nor the inference lattice, and it thus
4949
should be always unwrapped to `v.typ` when performing any type or lattice operations on it.
50+
5051
`v.undef` represents undefined-ness of this static parameter. If `true`, it means that the
5152
variable _may_ be undefined at runtime, otherwise it is guaranteed to be defined.
5253
If `v.typ === Bottom` it means that the variable is strictly undefined.
54+
55+
`v.ssadef` represents the "reaching definition" for the variable. If negative, this refers
56+
to a "virtual ϕ-block" preceding the given index. If a slot has the same `ssadef` at two
57+
different points of execution, the slot contents are guaranteed to share identity (`x₀ === x₁`).
5358
"""
5459
struct VarState
5560
typ
61+
ssadef::Int
5662
undef::Bool
57-
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
63+
VarState(@nospecialize(typ), ssadef::Int, undef::Bool) = new(typ, ssadef, undef)
5864
end
5965

6066
struct AnalysisResults

base/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2246,7 +2246,7 @@ function print_statement_costs(io::IO, @nospecialize(tt::Type);
22462246
else
22472247
empty!(cst)
22482248
resize!(cst, length(code.code))
2249-
sptypes = Core.Compiler.VarState[Core.Compiler.VarState(sp, false) for sp in match.sparams]
2249+
sptypes = Core.Compiler.VarState[Core.Compiler.VarState(sp, #= ssadef =# typemin(Int), false) for sp in match.sparams]
22502250
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, params)
22512251
nd = ndigits(maxcost)
22522252
irshow_config = IRShow.IRShowConfig() do io, linestart, idx

0 commit comments

Comments
 (0)