Skip to content

Commit d1e58b8

Browse files
authored
Merge pull request #3207 from apalache-mc/igor/fix3204
Fix transitive inlining of polytypes when receiving Quint as input
2 parents 9bd1a07 + d81b0c1 commit d1e58b8

File tree

4 files changed

+144
-29
lines changed

4 files changed

+144
-29
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Fix transitive inlining of polymorphic Quint definitions (#3207)

tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Inliner.scala

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import at.forsyte.apalache.tla.lir.transformations.standard.{
99
}
1010
import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker}
1111
import at.forsyte.apalache.tla.pp.Inliner.{DeclFilter, FilterFun}
12-
import at.forsyte.apalache.tla.types.{Substitution, TypeUnifier, TypeVarPool}
12+
import at.forsyte.apalache.tla.types.{EqClass, Substitution, TypeUnifier, TypeVarPool}
1313

1414
import scala.collection.immutable.SortedMap
1515

@@ -96,31 +96,45 @@ class Inliner(
9696
private def nonNullaryFilter(d: TlaOperDecl): Boolean =
9797
!keepNullaryMono || isPolyTag(d.typeTag) || d.formalParams.nonEmpty
9898

99-
// Given a declaration (possibly holding a polymorphic type) and a monotyped target, computes
100-
// a substitution of the two. A substitution is assumed to exist, otherwise TypingException is thrown.
101-
private def getSubstitution(targetType: TlaType1, decl: TlaOperDecl): (Substitution, TlaType1) = {
102-
val genericType = decl.typeTag.asTlaType1()
103-
val maxUsedVar = Math.max(genericType.usedNames.foldLeft(0)(Math.max), targetType.usedNames.foldLeft(0)(Math.max))
104-
new TypeUnifier(new TypeVarPool(maxUsedVar + 1)).unify(Substitution.empty, genericType, targetType) match {
99+
// Given a declaration `callee` (possibly holding a polymorphic type) and the type at a call site,
100+
// which can also be a polytype, computes a substitution of the two. A substitution is assumed to exist,
101+
// otherwise TypingException is thrown.
102+
private def getSubstitution(callSiteType: TlaType1, callee: TlaOperDecl): (Substitution, TlaType1) = {
103+
val calleeType = callee.typeTag.asTlaType1()
104+
// To fix transitive inlining of polymorphic operators, rename all type variables in `callee`
105+
// to fresh ones, so these type variables have indices larger than those in `callSiteType`.
106+
// The unification algorithm prefers to keep type variables with smaller indices,
107+
// so this ensures that type variables from `callSiteType` are kept.
108+
// See the issue #3204.
109+
val maxUsedVar = ((calleeType.usedNames ++ callSiteType.usedNames) ++ Set(0)).max
110+
val typeVarPool = new TypeVarPool(maxUsedVar + 1)
111+
val calleeSub = Substitution(calleeType.usedNames.map(v => EqClass(v) -> typeVarPool.fresh).toMap)
112+
val calleeTypeRenamed = calleeSub.subRec(calleeType)
113+
114+
new TypeUnifier(typeVarPool).unify(Substitution.empty, callSiteType, calleeTypeRenamed) match {
105115
case None =>
106-
throw new TypingException(
107-
s"Inliner: Unable to unify generic signature $genericType of ${decl.name} with the concrete type $targetType",
108-
decl.ID)
109-
110-
case Some(pair) => pair
116+
throw new TypingException(s"Inliner: Unable to unify the signature $calleeType of ${callee.name} "
117+
+ "with the type $callSiteType at call site", callee.ID)
118+
119+
case Some((unifierSub, unifiedType)) =>
120+
// Now, we have to add the renamed type variables back to the substitution.
121+
// To this end, we compose `calleeSub` with `unifierSub`.
122+
// Not that we are not merging them. Otherwise, the type variables of `callee` might take over.
123+
val composed = Substitution(calleeSub.mapping ++ unifierSub.mapping)
124+
(composed, composed.subRec(unifiedType))
111125
}
112126
}
113127

114-
// Assume an operator declaration named name is in scope.
128+
// Assume an operator declaration named `name` is in scope.
115129
// Creates a fresh copy of the operator body and replaces formal parameter instances with the argument instances.
116130
private def instantiateWithArgs(scope: Scope)(nameEx: NameEx, args: Seq[TlaEx]): TlaEx = {
117131
val name = nameEx.name
118-
val decl = scope(name)
132+
val callee = scope(name)
119133

120-
val freshBody = deepCopy(decl.body)
134+
val freshBody = deepCopy(callee.body)
121135

122136
// All formal parameters get instantiated at once, to avoid parameter-name issues, see #1903
123-
val paramMap = decl.formalParams
137+
val paramMap = callee.formalParams
124138
.zip(args)
125139
.map({ case (OperParam(name, _), arg) =>
126140
name -> arg
@@ -140,28 +154,30 @@ class Inliner(
140154
// To cover both cases at once, we run an additional transform on the replaced body
141155
val newBody = transform(scope)(replacedBody)
142156

143-
// Note: it can happen that the new body and the decl have desynced types (poly vs mono).
144-
// We fix that below with type unification.
145-
// If the operator has a parametric signature, we have to substitute type parameters with concrete parameters
146-
// 1. Unify the operator type with the arguments.
147-
// 2. Apply the resulting substitution to the types in all subexpressions.
148-
val actualType = nameEx.typeTag.asTlaType1()
157+
// Note: it can happen that the type at the call site and the type of `callee` have different types,
158+
// e.g., the `callee` has a more general polytype, or they are both polytypes.
159+
// We fix that below with type unification:
160+
// 1. Unify the operator type with the arguments.
161+
// 2. Apply the resulting substitution to the types in all subexpressions.
162+
// 3. Importantly, we prefer the type variables of the call site, as they are the type variables of
163+
// the caller context.
164+
val callSiteType = nameEx.typeTag.asTlaType1()
149165

150-
val (substitution, _) = getSubstitution(actualType, decl)
166+
val (substitution, _) = getSubstitution(callSiteType, callee)
151167

152168
if (substitution.isEmpty) newBody
153169
else new TypeSubstitutor(tracker, substitution)(newBody)
154170
}
155171

156172
// Assume name is in scope. Creates a local LET-IN for pass-by-name operators.
157173
private def embedPassByName(scope: Scope)(nameEx: NameEx): TlaEx = {
158-
val decl = scope(nameEx.name)
174+
val callee = scope(nameEx.name)
159175

160176
// like in instantiateWithArgs, we compare the declaration type to the expected monotype
161-
val freshBody = deepCopy(decl.body)
177+
val freshBody = deepCopy(callee.body)
162178
val monoOperType = nameEx.typeTag.asTlaType1()
163179

164-
val (substitution, tp) = getSubstitution(monoOperType, decl)
180+
val (substitution, tp) = getSubstitution(monoOperType, callee)
165181

166182
val tpTag = Typed(tp)
167183

@@ -170,8 +186,8 @@ class Inliner(
170186
else new TypeSubstitutor(tracker, substitution)(freshBody)
171187

172188
// To make a local definition, we use a fresh name, derived from the original name, but renamed to get a fresh $N
173-
val newName = renaming.apply(NameEx(decl.name)(decl.typeTag)).asInstanceOf[NameEx].name
174-
val newLocalDecl = TlaOperDecl(newName, decl.formalParams, newBody)(tpTag)
189+
val newName = renaming.apply(NameEx(callee.name)(callee.typeTag)).asInstanceOf[NameEx].name
190+
val newLocalDecl = TlaOperDecl(newName, callee.formalParams, newBody)(tpTag)
175191

176192
LetInEx(NameEx(newName)(tpTag), newLocalDecl)(tpTag)
177193
}

tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestInliner.scala

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package at.forsyte.apalache.tla.pp
33
import at.forsyte.apalache.tla.lir.TypedPredefs._
44
import at.forsyte.apalache.tla.lir._
55
import at.forsyte.apalache.tla.lir.convenience.tla
6+
import at.forsyte.apalache.tla.types.{tla => ttla}
67
import at.forsyte.apalache.tla.lir.oper.ApalacheOper
78
import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker
89
import at.forsyte.apalache.tla.lir.transformations.standard.IncrementalRenaming
@@ -11,11 +12,13 @@ import org.junit.runner.RunWith
1112
import org.scalatest.BeforeAndAfterEach
1213
import org.scalatest.funsuite.AnyFunSuite
1314
import org.scalatestplus.junit.JUnitRunner
15+
import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks
16+
import org.scalatest.prop.TableFor2
1417

1518
import scala.collection.immutable.SortedMap
1619

1720
@RunWith(classOf[JUnitRunner])
18-
class TestInliner extends AnyFunSuite with BeforeAndAfterEach {
21+
class TestInliner extends AnyFunSuite with BeforeAndAfterEach with ScalaCheckPropertyChecks {
1922

2023
import Inliner.Scope
2124

@@ -351,4 +354,95 @@ class TestInliner extends AnyFunSuite with BeforeAndAfterEach {
351354
assert(expected == actual)
352355
}
353356

357+
val typeVariableOrders: TableFor2[Int, Int] = Table(("a", "b"), (0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1))
358+
359+
test("Unify generic definitions under various type variable namings") {
360+
// Regression test for the issue #3204:
361+
// Inliner failed to unify generics definitions under certain type variable namings.
362+
forAll(typeVariableOrders) { (indexOfA: Int, indexOfB: Int) =>
363+
// P(x) == { x }; Q(y) == P(y); x' := Q({}) ~~> x' := { {} }
364+
val a = VarT1(indexOfA)
365+
val b = VarT1(indexOfB)
366+
// \* @type: a => Set(a);
367+
// P(x) == { x }
368+
val PBody = ttla.enumSet(ttla.name("x", a))
369+
val PDecl = ttla.decl("P", PBody, (OperParam("x"), a))
370+
// \* @type: b => b;
371+
// Q(y) == P(y)
372+
val PTypeInQ = OperT1(Seq(b), SetT1(b))
373+
val QBody = ttla.appOp(ttla.name("P", PTypeInQ), ttla.name("y", b))
374+
val QDecl = ttla.decl("Q", QBody, (OperParam("y"), b))
375+
376+
// @type: () => Bool;
377+
// X() == x' := Q({})
378+
val intSetSet = SetT1(SetT1(IntT1))
379+
val bodyOfX =
380+
ttla.assign(ttla.prime(ttla.name("x", intSetSet)),
381+
ttla
382+
.appOp(
383+
ttla.name("Q", OperT1(Seq(SetT1(IntT1)), SetT1(SetT1(IntT1)))),
384+
ttla.emptySet(IntT1),
385+
))
386+
val declOfX = ttla.decl("X", bodyOfX)
387+
388+
val decls = List(PDecl, QDecl, declOfX)
389+
val inputModule = mkModule(decls: _*)
390+
val outputModule = inlinerKeepNullary.transformModule(inputModule)
391+
val actualBodyOfX = outputModule.declarations(2).asInstanceOf[TlaOperDecl].body
392+
393+
// extract the type of {{}}
394+
val typeOfEmptySet = actualBodyOfX.asInstanceOf[OperEx].args(1).typeTag
395+
assert(typeOfEmptySet == Typed(SetT1(SetT1(IntT1))))
396+
}
397+
}
398+
399+
test("Unify generic definitions under various type variable namings in LET-IN") {
400+
// Regression test for the issue #3204:
401+
// Inliner failed to unify generics definitions under certain type variable namings.
402+
// This is the LET-IN version of the previous test.
403+
forAll(typeVariableOrders) { (indexOfA: Int, indexOfB: Int) =>
404+
// X ==
405+
// LET P(x) == { x }
406+
// Q(y) == P(y)
407+
// IN
408+
// x' := Q({}) ~~> x' := { {} }
409+
val a = VarT1(indexOfA)
410+
val b = VarT1(indexOfB)
411+
// \* @type: a => Set(a);
412+
// P(x) == { x }
413+
val PBody = ttla.enumSet(ttla.name("x", a))
414+
val PDecl = ttla.decl("P", PBody, (OperParam("x"), a))
415+
// \* @type: b => b;
416+
// Q(y) == P(y)
417+
val PTypeInQ = OperT1(Seq(b), SetT1(b))
418+
val QBody = ttla.appOp(ttla.name("P", PTypeInQ), ttla.name("y", b))
419+
val QDecl = ttla.decl("Q", QBody, (OperParam("y"), b))
420+
421+
// @type: () => Bool;
422+
// X() ==
423+
// LET P(x) == { x }
424+
// Q(y) == P(y)
425+
// IN
426+
// x' := Q({})
427+
val intSetSet = SetT1(SetT1(IntT1))
428+
val bodyOfX =
429+
ttla.assign(ttla.prime(ttla.name("x", intSetSet)),
430+
ttla
431+
.appOp(
432+
ttla.name("Q", OperT1(Seq(SetT1(IntT1)), SetT1(SetT1(IntT1)))),
433+
ttla.emptySet(IntT1),
434+
))
435+
val xUnderLetIn = ttla.letIn(ttla.letIn(bodyOfX, QDecl), PDecl)
436+
val declOfX = ttla.decl("X", xUnderLetIn)
437+
438+
val decls = List(declOfX)
439+
val inputModule = mkModule(decls: _*)
440+
val outputModule = inlinerKeepNullary.transformModule(inputModule)
441+
val actualBodyOfX = outputModule.declarations.head.asInstanceOf[TlaOperDecl].body
442+
443+
// extract the type of {{}}
444+
val typeOfEmptySet = actualBodyOfX.asInstanceOf[OperEx].args(1).typeTag
445+
assert(typeOfEmptySet == Typed(SetT1(SetT1(IntT1))))
446+
}
447+
}
354448
}

tlair/src/main/scala/at/forsyte/apalache/tla/types/TypeUnifier.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ class TypeUnifier(varPool: TypeVarPool) {
3535
* is the solution set showing how to unify lhs and rhs and t is the type resulting from successfully unifying lhs and
3636
* rhs using mgu. Note that apart from variable substitution, our unification also involves merging record types. When
3737
* there is no unifier, it returns None.
38+
*
39+
* <b>WARNING:</b> When two type variables are unified into a single equivalence class, the variable with the
40+
* <i>smaller</i> index becomes the representative of the class. Thus, unification may substitute `b` with `a`, but
41+
* never `a` with `b`. This property MUST be preserved, as the calling code relies on it.
3842
*/
3943
def unify(substitution: Substitution, lhs: TlaType1, rhs: TlaType1): Option[(Substitution, TlaType1)] = {
4044
// Copy the equivalence classes and the mapping from the substitution.

0 commit comments

Comments
 (0)