Skip to content

Commit 5a2faa4

Browse files
authored
inference: allow PartialStruct to represent strict undef field (#57541)
This change allows `PartialStruct` to represent structs with strictly uninitialized fields. The previous `undef::BitVector` field is replaced by `undefs::Vector{Union{Nothing,Bool}}`, which encodes the defined-ness of each field (`false`: known to be defined, `true`: known to be undefined, `nothing`: unknown). This also lets us fix the length of `typ::PartialStruct`'s `fields` so that it always matches the number of fields in `typ.typ`. In the previous design the length of `fields` varied with the number of initialized fields; it is simpler for all `PartialStruct`s representing the same `typ` to have the same `fields` length, so that refactoring is included as well.
- fixes the newly detected error in Oscar.jl
1 parent 9a7fab2 commit 5a2faa4

File tree

10 files changed

+373
-277
lines changed

10 files changed

+373
-277
lines changed

Compiler/src/Compiler.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali
6969
structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout,
7070
uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal,
7171
_uncompressed_ir, maybe_add_binding_backedge!, datatype_min_ninitialized,
72-
partialstruct_undef_length, partialstruct_init_undef
72+
partialstruct_init_undefs, fieldcount_noerror
7373
using Base.Order
7474

7575
import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!,
@@ -83,10 +83,6 @@ const modifyproperty! = Core.modifyfield!
8383
const replaceproperty! = Core.replacefield!
8484
const _DOCS_ALIASING_WARNING = ""
8585

86-
function _getundef(p::PartialStruct)
87-
Base.getproperty(p, :undef)
88-
end
89-
9086
ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Compiler, false)
9187

9288
eval(x) = Core.eval(Compiler, x)

Compiler/src/abstractinterpretation.jl

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,7 +2115,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
21152115
end
21162116
return Conditional(a, thentype, elsetype)
21172117
else
2118-
thentype = form_partially_defined_struct(argtype2, argtypes[3])
2118+
thentype = form_partially_defined_struct(𝕃ᵢ, argtype2, argtypes[3])
21192119
if thentype !== nothing
21202120
elsetype = argtype2
21212121
if rt === Const(false)
@@ -2133,22 +2133,32 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
21332133
return rt
21342134
end
21352135

2136-
function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
2136+
function form_partially_defined_struct(𝕃ᵢ::AbstractLattice, @nospecialize(obj), @nospecialize(name))
21372137
obj isa Const && return nothing # nothing to refine
21382138
name isa Const || return nothing
21392139
objt0 = widenconst(obj)
21402140
objt = unwrap_unionall(objt0)
21412141
objt isa DataType || return nothing
21422142
isabstracttype(objt) && return nothing
2143+
objt <: Tuple && return nothing
21432144
fldidx = try_compute_fieldidx(objt, name.val)
21442145
fldidx === nothing && return nothing
2145-
isa(obj, PartialStruct) && return define_field(obj, fldidx)
2146+
if isa(obj, PartialStruct)
2147+
_getundefs(obj)[fldidx] === false && return nothing
2148+
newundefs = copy(_getundefs(obj))
2149+
newundefs[fldidx] = false
2150+
return PartialStruct(𝕃ᵢ, obj.typ, newundefs, copy(obj.fields))
2151+
end
21462152
nminfld = datatype_min_ninitialized(objt)
2147-
fldidx > nminfld || return nothing
2148-
undef = partialstruct_init_undef(objt, fldidx; all_defined = false)
2149-
undef[fldidx] = false
2150-
fields = Any[fieldtype(objt0, i) for i = 1:fldidx]
2151-
return PartialStruct(fallback_lattice, objt0, undef, fields)
2153+
fldidx ≤ nminfld && return nothing
2154+
fldcnt = fieldcount_noerror(objt)::Int
2155+
fields = Any[fieldtype(objt0, i) for i = 1:fldcnt]
2156+
if fields[fldidx] === Union{}
2157+
return nothing # `Union{}` field never transitions to be defined
2158+
end
2159+
undefs = partialstruct_init_undefs(objt, fldcnt)
2160+
undefs[fldidx] = false
2161+
return PartialStruct(𝕃ᵢ, objt0, undefs, fields)
21522162
end
21532163

21542164
function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
@@ -2663,7 +2673,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
26632673
# so we try to encode that information with a `PartialStruct`
26642674
farg2 = ssa_def_slot(fargs[2], sv)
26652675
if farg2 isa SlotNumber
2666-
refined = form_partially_defined_struct(argtypes[2], argtypes[3])
2676+
refined = form_partially_defined_struct(𝕃ᵢ, argtypes[2], argtypes[3])
26672677
if refined !== nothing
26682678
refinements = SlotRefinement(farg2, refined)
26692679
end
@@ -3035,7 +3045,6 @@ function abstract_eval_call(interp::AbstractInterpreter, e::Expr, sstate::Statem
30353045
end
30363046
end
30373047

3038-
30393048
function abstract_eval_new(interp::AbstractInterpreter, e::Expr, sstate::StatementState,
30403049
sv::AbsIntState)
30413050
𝕃ᵢ = typeinf_lattice(interp)
@@ -3093,7 +3102,30 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, sstate::Stateme
30933102
# - any refinement information is available (`anyrefine`), or when
30943103
# - `nargs` is greater than `n_initialized` derived from the struct type
30953104
# information alone
3096-
rt = PartialStruct(𝕃ᵢ, rt, ats)
3105+
undefs = Union{Nothing,Bool}[false for _ in 1:nargs]
3106+
if nargs < fcount # fill in uninitialized fields
3107+
for i = (nargs+1):fcount
3108+
ft = fieldtype(rt, i)
3109+
push!(ats, ft)
3110+
if ft === Union{} # `Union{}`-typed field is never initialized
3111+
push!(undefs, true)
3112+
elseif isconcretetype(ft) && datatype_pointerfree(ft) # this check is probably incomplete
3113+
push!(undefs, false)
3114+
# TODO If we can implement the query such that it accurately
3115+
# identifies fields that can never be `#undef`'d, we can make the
3116+
# following improvements:
3117+
# elseif is_field_pointerfree(rt, i)
3118+
# push!(undefs, false)
3119+
# elseif ismutable && !isconst(rt, i) # can't constrain this field (as it may be modified later)
3120+
# push!(undefs, nothing)
3121+
# else
3122+
# push!(undefs, true)
3123+
else
3124+
push!(undefs, nothing)
3125+
end
3126+
end
3127+
end
3128+
rt = PartialStruct(𝕃ᵢ, rt, undefs, ats)
30973129
end
30983130
else
30993131
rt = refine_partial_type(rt)
@@ -3122,13 +3154,18 @@ function abstract_eval_splatnew(interp::AbstractInterpreter, e::Expr, sstate::St
31223154
end))
31233155
nothrow = isexact
31243156
rt = Const(ccall(:jl_new_structt, Any, (Any, Any), rt, at.val))
3125-
elseif (isa(at, PartialStruct) && ⊑(𝕃ᵢ, at, Tuple) && n > 0 &&
3126-
n == length(at.fields::Vector{Any}) && !isvarargtype(at.fields[end]) &&
3127-
(let t = rt, at = at
3128-
all(i::Int -> ⊑(𝕃ᵢ, (at.fields::Vector{Any})[i], fieldtype(t, i)), 1:n)
3129-
end))
3130-
nothrow = isexact
3131-
rt = PartialStruct(𝕃ᵢ, rt, at.fields::Vector{Any})
3157+
elseif at isa PartialStruct
3158+
if ⊑(𝕃ᵢ, at, Tuple) && n > 0
3159+
fields = at.fields
3160+
if (n == length(fields) && !isvarargtype(fields[end]) &&
3161+
(let t = rt
3162+
all(i::Int -> ⊑(𝕃ᵢ, fields[i], fieldtype(t, i)), 1:n)
3163+
end))
3164+
nothrow = isexact
3165+
undefs = Union{Nothing,Bool}[false for _ in 1:n]
3166+
rt = PartialStruct(𝕃ᵢ, rt, undefs, fields)
3167+
end
3168+
end
31323169
end
31333170
else
31343171
rt = refine_partial_type(rt)
@@ -3713,7 +3750,7 @@ end
37133750
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
37143751
if isa(rt, PartialStruct)
37153752
fields = copy(rt.fields)
3716-
anyrefine = refines_definedness_information(rt)
3753+
anyrefine = n_initialized(rt) > datatype_min_ninitialized(rt.typ)
37173754
𝕃 = typeinf_lattice(info.interp)
37183755
⊏ = strictpartialorder(𝕃)
37193756
for i in 1:length(fields)
@@ -3725,7 +3762,7 @@ end
37253762
end
37263763
fields[i] = a
37273764
end
3728-
anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, rt.undef, fields)
3765+
anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, _getundefs(rt), fields)
37293766
end
37303767
if isa(rt, PartialOpaque)
37313768
return rt # XXX: this case was missed in #39512

Compiler/src/tfuncs.jl

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,9 @@ end
439439
end
440440
elseif isa(arg1, PartialStruct)
441441
if !isvarargtype(arg1.fields[end])
442-
if !is_field_maybe_undef(arg1, idx)
443-
return Const(true)
442+
aundefᵢ = _getundefs(arg1)[idx]
443+
if aundefᵢ isa Bool
444+
return Const(!aundefᵢ)
444445
end
445446
end
446447
elseif !isvatuple(a1)
@@ -906,31 +907,6 @@ add_tfunc(<:, 2, 2, subtype_tfunc, 10)
906907
return lty ⊑ Type && rty ⊑ Type
907908
end
908909

909-
function fieldcount_noerror(@nospecialize t)
910-
if t isa UnionAll || t isa Union
911-
t = argument_datatype(t)
912-
if t === nothing
913-
return nothing
914-
end
915-
elseif t === Union{}
916-
return 0
917-
end
918-
t isa DataType || return nothing
919-
if t.name === _NAMEDTUPLE_NAME
920-
names, types = t.parameters
921-
if names isa Tuple
922-
return length(names)
923-
end
924-
if types isa DataType && types <: Tuple
925-
return fieldcount_noerror(types)
926-
end
927-
return nothing
928-
elseif isabstracttype(t) || (t.name === Tuple.name && isvatuple(t))
929-
return nothing
930-
end
931-
return isdefined(t, :types) ? length(t.types) : length(t.name.names)
932-
end
933-
934910
function try_compute_fieldidx(@nospecialize(typ), @nospecialize(field))
935911
typ = argument_datatype(typ)
936912
typ === nothing && return nothing
@@ -1141,8 +1117,12 @@ end
11411117
sty = unwrap_unionall(s)::DataType
11421118
if isa(name, Const)
11431119
nv = _getfield_fieldindex(sty, name)
1144-
if isa(nv, Int) && !is_field_maybe_undef(s00, nv)
1145-
return unwrapva(partialstruct_getfield(s00, nv))
1120+
if isa(nv, Int)
1121+
if nv < 1
1122+
return Bottom
1123+
elseif nv ≤ length(s00.fields)
1124+
return unwrapva(s00.fields[nv])
1125+
end
11461126
end
11471127
end
11481128
s00 = s
@@ -1437,7 +1417,7 @@ end
14371417
if TF2 === Bottom
14381418
RT = Bottom
14391419
elseif isconcretetype(RT) && has_nontrivial_extended_info(𝕃ᵢ, TF2) # isconcrete condition required to form a PartialStruct
1440-
RT = PartialStruct(fallback_lattice, RT, Any[TF, TF2])
1420+
RT = PartialStruct(fallback_lattice, RT, Union{Nothing,Bool}[false,false], Any[TF, TF2])
14411421
end
14421422
info = ModifyOpInfo(callinfo.info)
14431423
return CallMeta(RT, Any, Effects(), info)
@@ -2015,15 +1995,15 @@ function tuple_tfunc(𝕃::AbstractLattice, argtypes::Vector{Any})
20151995
typ = Tuple{params...}
20161996
# replace a singleton type with its equivalent Const object
20171997
issingletontype(typ) && return Const(typ.instance)
2018-
return anyinfo ? PartialStruct(𝕃, typ, argtypes) : typ
1998+
return anyinfo ? PartialStruct(𝕃, typ, partialstruct_init_undefs(typ, argtypes), argtypes) : typ
20191999
end
20202000

20212001
@nospecs function memorynew_tfunc(𝕃::AbstractLattice, memtype, memlen)
20222002
hasintersect(widenconst(memlen), Int) || return Bottom
20232003
memt = tmeet(𝕃, instanceof_tfunc(memtype, true)[1], GenericMemory)
20242004
memt == Union{} && return memt
20252005
# PartialStruct so that loads of Const `length` get inferred
2026-
return PartialStruct(𝕃, memt, Any[memlen, Ptr{Nothing}])
2006+
return PartialStruct(𝕃, memt, Union{Nothing,Bool}[false,false], Any[memlen, Ptr{Nothing}])
20272007
end
20282008
add_tfunc(Core.memorynew, 2, 2, memorynew_tfunc, 10)
20292009

Compiler/src/typeinfer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter, cycleid::
514514
rettype_const = result_type.parameters[1]
515515
const_flags = 0x2
516516
elseif isa(result_type, PartialStruct)
517-
rettype_const = (_getundef(result_type), result_type.fields)
517+
rettype_const = (_getundefs(result_type), result_type.fields)
518518
const_flags = 0x2
519519
elseif isa(result_type, InterConditional)
520520
rettype_const = result_type
@@ -958,9 +958,9 @@ function cached_return_type(code::CodeInstance)
958958
rettype_const = code.rettype_const
959959
# the second subtyping/egal conditions are necessary to distinguish usual cases
960960
# from rare cases when `Const` wrapped those extended lattice type objects
961-
if isa(rettype_const, Tuple{BitVector, Vector{Any}}) && !(Tuple{BitVector, Vector{Any}} <: rettype)
962-
undef, fields = rettype_const
963-
return PartialStruct(fallback_lattice, rettype, undef, fields)
961+
if isa(rettype_const, Tuple{Vector{Union{Nothing,Bool}}, Vector{Any}}) && !(Tuple{Vector{Union{Nothing,Bool}}, Vector{Any}} <: rettype)
962+
undefs, fields = rettype_const
963+
return PartialStruct(fallback_lattice, rettype, undefs, fields)
964964
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
965965
return rettype_const
966966
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional

Compiler/src/typelattice.jl

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,15 @@ end
318318
fields = vartyp.fields
319319
thenfields = thentype === Bottom ? nothing : copy(fields)
320320
elsefields = elsetype === Bottom ? nothing : copy(fields)
321-
undef = copy(_getundef(vartyp))
321+
undefs = copy(_getundefs(vartyp))
322322
if 1 ≤ fldidx ≤ length(fields)
323323
thenfields === nothing || (thenfields[fldidx] = thentype)
324324
elsefields === nothing || (elsefields[fldidx] = elsetype)
325-
undef[fldidx] = false
325+
undefs[fldidx] = false
326326
end
327327
return Conditional(slot,
328-
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, thenfields),
329-
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, elsefields))
328+
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undefs, thenfields),
329+
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undefs, elsefields))
330330
else
331331
vartyp_widened = widenconst(vartyp)
332332
thenfields = thentype === Bottom ? nothing : Any[]
@@ -424,21 +424,15 @@ end
424424
if isa(a, PartialStruct)
425425
if isa(b, PartialStruct)
426426
a.typ <: b.typ || return false
427-
if length(a.fields) ≠ length(b.fields)
428-
if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end]))
429-
length(a.fields) ≥ length(b.fields) || return false
430-
else
427+
nflds = length(a.fields)
428+
nflds == length(b.fields) || return false
429+
for i in 1:nflds
430+
if !(_getundefs(b)[i] === nothing || _getundefs(a)[i] === _getundefs(b)[i])
431431
return false
432432
end
433-
end
434-
na = length(a.fields)
435-
nb = length(b.fields)
436-
nmax = max(na, nb)
437-
for i in 1:nmax
438-
is_field_maybe_undef(a, i) ≤ is_field_maybe_undef(b, i) || return false
439-
af = partialstruct_getfield(a, i)
440-
bf = partialstruct_getfield(b, i)
441-
if i == na || i == nb
433+
af = a.fields[i]
434+
bf = b.fields[i]
435+
if i == nflds
442436
if isvarargtype(af)
443437
# If `af` is vararg, so must bf by the <: above
444438
@assert isvarargtype(bf)
@@ -474,10 +468,13 @@ end
474468
nf = nfields(a.val)
475469
for i in 1:nf
476470
if !isdefined(a.val, i)
477-
is_field_maybe_undef(b, i) || return false # conflicting defined-ness information
471+
_getundefs(b)[i] === false && return false # conflicting defined-ness information
478472
continue # since ∀ T Union{} ⊑ T
479473
end
480474
i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct
475+
if _getundefs(b)[i] === true
476+
return false # conflicting defined-ness information
477+
end
481478
bfᵢ = b.fields[i]
482479
if i == nf
483480
bfᵢ = unwrapva(bfᵢ)
@@ -548,7 +545,7 @@ end
548545
if isa(a, PartialStruct)
549546
isa(b, PartialStruct) || return false
550547
length(a.fields) == length(b.fields) || return false
551-
_getundef(a) == _getundef(b) || return false
548+
_getundefs(a) == _getundefs(b) || return false
552549
widenconst(a) == widenconst(b) || return false
553550
a.fields === b.fields && return true # fast path
554551
for i in 1:length(a.fields)
@@ -756,14 +753,18 @@ end
756753
# different instances of the compiler that may share the `Core.PartialStruct`
757754
# type.
758755

759-
function Core.PartialStruct(𝕃::AbstractLattice, @nospecialize(typ), fields::Vector{Any}; all_defined::Bool = true)
760-
undef = partialstruct_init_undef(typ, fields; all_defined)
761-
return PartialStruct(𝕃, typ, undef, fields)
756+
# Legacy constructor
757+
function Core.PartialStruct(𝕃::AbstractLattice, @nospecialize(typ), fields::Vector{Any})
758+
return PartialStruct(𝕃, typ, partialstruct_init_undefs(typ, fields), fields)
762759
end
763760

764-
function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), undef::BitVector, fields::Vector{Any})
761+
function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), undefs::Vector{Union{Nothing,Bool}}, fields::Vector{Any})
765762
for i = 1:length(fields)
766763
assert_nested_slotwrapper(fields[i])
767764
end
768-
return PartialStruct(typ, undef, fields)
765+
return PartialStruct(typ, undefs, fields)
769766
end
767+
768+
# a special getter for `PartialStruct` to achieve better type stability:
769+
# `(x::PartialStruct).undefs` will be lowered to `getfield(x, :undefs)::Any` otherwise
770+
_getundefs(p::PartialStruct) = Base.getproperty(p, :undefs)

0 commit comments

Comments
 (0)