Skip to content

Commit 1a8387f

Browse files
martinholtersvtjnash
authored andcommitted
Fix a precision issue in abstract_iteration (#41839)
If the first loop exits in the first iteration, the `statetype` is still `Bottom`. In that case, the new `stateordonet` needs to be determined with the two-arg version of `iterate` again. Explicitly test that inference produces a sound (and reasonably precise) result when splatting an iterator (in this case a long range) that allows constant-propagation up to the `MAX_TUPLE_SPLAT` limit. Fixes #41022 Co-authored-by: Jameson Nash <vtjnash@gmail.com> (cherry picked from commit 92337b5)
1 parent 2116922 commit 1a8387f

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -796,9 +796,11 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
796796
return ret, AbstractIterationInfo(calls)
797797
end
798798
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
799+
stateordonet = stateordonet_widened
799800
break
800801
end
801802
if !isa(stateordonet_widened, DataType) || !(stateordonet_widened <: Tuple) || isvatuple(stateordonet_widened) || length(stateordonet_widened.parameters) != 2
803+
stateordonet = stateordonet_widened
802804
break
803805
end
804806
nstatetype = getfield_tfunc(stateordonet, Const(2))
@@ -816,27 +818,40 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
816818
end
817819
# From here on, we start asking for results on the widened types, rather than
818820
# the precise (potentially const) state type
819-
statetype = widenconst(statetype)
820-
valtype = widenconst(valtype)
821+
# statetype and valtype are reinitialized in the first iteration below from the
822+
# (widened) stateordonet, which has not yet been fully analyzed in the loop above
823+
statetype = Bottom
824+
valtype = Bottom
825+
may_have_terminated = Nothing <: stateordonet
821826
while valtype !== Any
822-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
823-
stateordonet = widenconst(stateordonet)
824-
nounion = typesubtract(stateordonet, Nothing, 0)
825-
if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2
827+
nounion = typeintersect(stateordonet, Tuple{Any,Any})
828+
if nounion !== Union{} && !isa(nounion, DataType)
829+
# nounion is of a type we cannot handle
826830
valtype = Any
827831
break
828832
end
829-
if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype
833+
if nounion === Union{} || (nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype)
834+
# reached a fixpoint or iterator failed/gave invalid answer
830835
if typeintersect(stateordonet, Nothing) === Union{}
831-
# Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing
832-
return Any[Bottom], nothing
836+
# ... but cannot terminate
837+
if !may_have_terminated
838+
# ... and cannot have terminated prior to this loop
839+
return Any[Bottom], nothing
840+
else
841+
# iterator may have terminated prior to this loop, but not during it
842+
valtype = Bottom
843+
end
833844
end
834845
break
835846
end
836847
valtype = tmerge(valtype, nounion.parameters[1])
837848
statetype = tmerge(statetype, nounion.parameters[2])
849+
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
850+
stateordonet = widenconst(stateordonet)
851+
end
852+
if valtype !== Union{}
853+
push!(ret, Vararg{valtype})
838854
end
839-
push!(ret, Vararg{valtype})
840855
return ret, nothing
841856
end
842857

test/compiler/inference.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2880,9 +2880,24 @@ partial_return_2(x) = Val{partial_return_1(x)[2]}
28802880

28812881
@test Base.return_types(partial_return_2, (Int,)) == Any[Type{Val{1}}]
28822882

2883-
# Precision of abstract_iteration
2883+
# Soundness and precision of abstract_iteration
2884+
f41839() = (1:100...,)
2885+
@test NTuple{100,Int} <: only(Base.return_types(f41839, ())) <: Tuple{Vararg{Int}}
28842886
f_splat(x) = (x...,)
28852887
@test Base.return_types(f_splat, (Pair{Int,Int},)) == Any[Tuple{Int, Int}]
2888+
@test Base.return_types(f_splat, (UnitRange{Int},)) == Any[Tuple{Vararg{Int}}]
2889+
struct Itr41839_1 end # empty or infinite
2890+
Base.iterate(::Itr41839_1) = rand(Bool) ? (nothing, nothing) : nothing
2891+
Base.iterate(::Itr41839_1, ::Nothing) = (nothing, nothing)
2892+
@test Base.return_types(f_splat, (Itr41839_1,)) == Any[Tuple{}]
2893+
struct Itr41839_2 end # empty or failing
2894+
Base.iterate(::Itr41839_2) = rand(Bool) ? (nothing, nothing) : nothing
2895+
Base.iterate(::Itr41839_2, ::Nothing) = error()
2896+
@test Base.return_types(f_splat, (Itr41839_2,)) == Any[Tuple{}]
2897+
struct Itr41839_3 end
2898+
Base.iterate(::Itr41839_3 ) = rand(Bool) ? nothing : (nothing, 1)
2899+
Base.iterate(::Itr41839_3 , i) = i < 16 ? (i, i + 1) : nothing
2900+
@test only(Base.return_types(f_splat, (Itr41839_3,))) <: Tuple{Vararg{Union{Nothing, Int}}}
28862901

28872902
# issue #32699
28882903
f32699(a) = (id = a[1],).id

0 commit comments

Comments
 (0)