Skip to content

Commit d384d78

Browse files
martinholtersvtjnash
authored andcommitted
Fix a precision issue in abstract_iteration (#41839)
If the first loop exits in the first iteration, the `statetype` is still `Bottom`. In that case, the new `stateordonet` needs to be determined with the two-arg version of `iterate` again. Explicitly test that inference produces a sound (and reasonably precise) result when splatting an iterator (in this case a long range) that allows constant-propagation up to the `MAX_TUPLE_SPLAT` limit. Fixes #41022 Co-authored-by: Jameson Nash <vtjnash@gmail.com> (cherry picked from commit 92337b5)
1 parent a7081a6 commit d384d78

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,11 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
615615
return ret, AbstractIterationInfo(calls)
616616
end
617617
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
618+
stateordonet = stateordonet_widened
618619
break
619620
end
620621
if !isa(stateordonet_widened, DataType) || !(stateordonet_widened <: Tuple) || isvatuple(stateordonet_widened) || length(stateordonet_widened.parameters) != 2
622+
stateordonet = stateordonet_widened
621623
break
622624
end
623625
nstatetype = getfield_tfunc(stateordonet, Const(2))
@@ -635,27 +637,40 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
635637
end
636638
# From here on, we start asking for results on the widened types, rather than
637639
# the precise (potentially const) state type
638-
statetype = widenconst(statetype)
639-
valtype = widenconst(valtype)
640+
# statetype and valtype are reinitialized in the first iteration below from the
641+
# (widened) stateordonet, which has not yet been fully analyzed in the loop above
642+
statetype = Bottom
643+
valtype = Bottom
644+
may_have_terminated = Nothing <: stateordonet
640645
while valtype !== Any
641-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
642-
stateordonet = widenconst(stateordonet)
643-
nounion = typesubtract(stateordonet, Nothing, 0)
644-
if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2
646+
nounion = typeintersect(stateordonet, Tuple{Any,Any})
647+
if nounion !== Union{} && !isa(nounion, DataType)
648+
# nounion is of a type we cannot handle
645649
valtype = Any
646650
break
647651
end
648-
if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype
652+
if nounion === Union{} || (nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype)
653+
# reached a fixpoint or iterator failed/gave invalid answer
649654
if typeintersect(stateordonet, Nothing) === Union{}
650-
# Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing
651-
return Any[Bottom], nothing
655+
# ... but cannot terminate
656+
if !may_have_terminated
657+
# ... and cannot have terminated prior to this loop
658+
return Any[Bottom], nothing
659+
else
660+
# iterator may have terminated prior to this loop, but not during it
661+
valtype = Bottom
662+
end
652663
end
653664
break
654665
end
655666
valtype = tmerge(valtype, nounion.parameters[1])
656667
statetype = tmerge(statetype, nounion.parameters[2])
668+
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
669+
stateordonet = widenconst(stateordonet)
670+
end
671+
if valtype !== Union{}
672+
push!(ret, Vararg{valtype})
657673
end
658-
push!(ret, Vararg{valtype})
659674
return ret, nothing
660675
end
661676

test/compiler/inference.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2740,9 +2740,24 @@ partial_return_2(x) = Val{partial_return_1(x)[2]}
27402740

27412741
@test Base.return_types(partial_return_2, (Int,)) == Any[Type{Val{1}}]
27422742

2743-
# Precision of abstract_iteration
2743+
# Soundness and precision of abstract_iteration
2744+
f41839() = (1:100...,)
2745+
@test NTuple{100,Int} <: only(Base.return_types(f41839, ())) <: Tuple{Vararg{Int}}
27442746
f_splat(x) = (x...,)
27452747
@test Base.return_types(f_splat, (Pair{Int,Int},)) == Any[Tuple{Int, Int}]
2748+
@test Base.return_types(f_splat, (UnitRange{Int},)) == Any[Tuple{Vararg{Int}}]
2749+
struct Itr41839_1 end # empty or infinite
2750+
Base.iterate(::Itr41839_1) = rand(Bool) ? (nothing, nothing) : nothing
2751+
Base.iterate(::Itr41839_1, ::Nothing) = (nothing, nothing)
2752+
@test Base.return_types(f_splat, (Itr41839_1,)) == Any[Tuple{}]
2753+
struct Itr41839_2 end # empty or failing
2754+
Base.iterate(::Itr41839_2) = rand(Bool) ? (nothing, nothing) : nothing
2755+
Base.iterate(::Itr41839_2, ::Nothing) = error()
2756+
@test Base.return_types(f_splat, (Itr41839_2,)) == Any[Tuple{}]
2757+
struct Itr41839_3 end
2758+
Base.iterate(::Itr41839_3 ) = rand(Bool) ? nothing : (nothing, 1)
2759+
Base.iterate(::Itr41839_3 , i) = i < 16 ? (i, i + 1) : nothing
2760+
@test only(Base.return_types(f_splat, (Itr41839_3,))) <: Tuple{Vararg{Union{Nothing, Int}}}
27462761

27472762
# issue #32699
27482763
f32699(a) = (id = a[1],).id

0 commit comments

Comments
 (0)