Skip to content

Commit 317badd

Browse files
vtjnashKristofferC
authored andcommitted
inference: correction to ifelse Conditional lattice
Rename typeassert_type_instance to tjoin (aka typeintersect). Also, since the ifelse value here might not be in the regular type lattice, we need to use the extended lattice for this evaluation. (cherry picked from commit 95d03f9)
1 parent 19e657b commit 317badd

File tree

3 files changed

+59
-46
lines changed

3 files changed

+59
-46
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -788,20 +788,28 @@ end
788788
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}},
789789
argtypes::Vector{Any}, sv::InferenceState, max_methods::Int)
790790
la = length(argtypes)
791-
if f === ifelse && fargs isa Vector{Any} && la == 4 && argtypes[2] isa Conditional
792-
# try to simulate this as a real conditional (`cnd ? x : y`), so that the penalty for using `ifelse` instead isn't too high
793-
cnd = argtypes[2]::Conditional
794-
tx = argtypes[3]
795-
ty = argtypes[4]
796-
a = ssa_def_slot(fargs[3], sv)
797-
b = ssa_def_slot(fargs[4], sv)
798-
if isa(a, Slot) && slot_id(cnd.var) == slot_id(a)
799-
tx = typeintersect(tx, cnd.vtype)
800-
end
801-
if isa(b, Slot) && slot_id(cnd.var) == slot_id(b)
802-
ty = typeintersect(ty, cnd.elsetype)
803-
end
804-
return tmerge(tx, ty)
791+
if f === ifelse && fargs isa Vector{Any} && la == 4
792+
cnd = argtypes[2]
793+
if isa(cnd, Conditional)
794+
newcnd = widenconditional(cnd)
795+
if isa(newcnd, Const)
796+
# if `cnd` is constant, we should just respect its constantness to keep inference accuracy
797+
return newcnd.val ? tx : ty
798+
else
799+
# try to simulate this as a real conditional (`cnd ? x : y`), so that the penalty for using `ifelse` instead isn't too high
800+
tx = argtypes[3]
801+
ty = argtypes[4]
802+
a = ssa_def_slot(fargs[3], sv)
803+
b = ssa_def_slot(fargs[4], sv)
804+
if isa(a, Slot) && slot_id(cnd.var) == slot_id(a)
805+
tx = (cnd.vtype tx ? cnd.vtype : tmeet(tx, widenconst(cnd.vtype)))
806+
end
807+
if isa(b, Slot) && slot_id(cnd.var) == slot_id(b)
808+
ty = (cnd.elsetype ty ? cnd.elsetype : tmeet(ty, widenconst(cnd.elsetype)))
809+
end
810+
return tmerge(tx, ty)
811+
end
812+
end
805813
end
806814
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
807815
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple

base/compiler/tfuncs.jl

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -544,41 +544,10 @@ function typeof_tfunc(@nospecialize(t))
544544
end
545545
add_tfunc(typeof, 1, 1, typeof_tfunc, 0)
546546

547-
function typeassert_type_instance(@nospecialize(v), @nospecialize(t))
548-
if isa(v, Const)
549-
if !has_free_typevars(t) && !isa(v.val, t)
550-
return Bottom
551-
end
552-
return v
553-
elseif isa(v, PartialStruct)
554-
has_free_typevars(t) && return v
555-
widev = widenconst(v)
556-
if widev <: t
557-
return v
558-
elseif typeintersect(widev, t) === Bottom
559-
return Bottom
560-
end
561-
@assert widev <: Tuple
562-
new_fields = Vector{Any}(undef, length(v.fields))
563-
for i = 1:length(new_fields)
564-
new_fields[i] = typeassert_type_instance(v.fields[i], getfield_tfunc(t, Const(i)))
565-
if new_fields[i] === Bottom
566-
return Bottom
567-
end
568-
end
569-
return tuple_tfunc(new_fields)
570-
elseif isa(v, Conditional)
571-
if !(Bool <: t)
572-
return Bottom
573-
end
574-
return v
575-
end
576-
return typeintersect(widenconst(v), t)
577-
end
578547
function typeassert_tfunc(@nospecialize(v), @nospecialize(t))
579548
t = instanceof_tfunc(t)[1]
580549
t === Any && return v
581-
return typeassert_type_instance(v, t)
550+
return tmeet(v, t)
582551
end
583552
add_tfunc(typeassert, 2, 2, typeassert_tfunc, 4)
584553

base/compiler/typelimits.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,3 +521,39 @@ function tuplemerge(a::DataType, b::DataType)
521521
end
522522
return Tuple{p...}
523523
end
524+
525+
# compute typeintersect over the extended inference lattice
526+
# where v is in the extended lattice, and t is a Type
527+
function tmeet(@nospecialize(v), @nospecialize(t))
528+
if isa(v, Const)
529+
if !has_free_typevars(t) && !isa(v.val, t)
530+
return Bottom
531+
end
532+
return v
533+
elseif isa(v, PartialStruct)
534+
has_free_typevars(t) && return v
535+
widev = widenconst(v)
536+
if widev <: t
537+
return v
538+
end
539+
ti = typeintersect(widev, t)
540+
if ti === Bottom
541+
return Bottom
542+
end
543+
@assert widev <: Tuple
544+
new_fields = Vector{Any}(undef, length(v.fields))
545+
for i = 1:length(new_fields)
546+
new_fields[i] = tmeet(v.fields[i], widenconst(getfield_tfunc(t, Const(i))))
547+
if new_fields[i] === Bottom
548+
return Bottom
549+
end
550+
end
551+
return tuple_tfunc(new_fields)
552+
elseif isa(v, Conditional)
553+
if !(Bool <: t)
554+
return Bottom
555+
end
556+
return v
557+
end
558+
return typeintersect(widenconst(v), t)
559+
end

0 commit comments

Comments
 (0)