@@ -9,7 +9,7 @@ import at.forsyte.apalache.tla.lir.transformations.standard.{
99}
1010import at .forsyte .apalache .tla .lir .transformations .{TlaExTransformation , TransformationTracker }
1111import at .forsyte .apalache .tla .pp .Inliner .{DeclFilter , FilterFun }
12- import at .forsyte .apalache .tla .types .{Substitution , TypeUnifier , TypeVarPool }
12+ import at .forsyte .apalache .tla .types .{EqClass , Substitution , TypeUnifier , TypeVarPool }
1313
1414import scala .collection .immutable .SortedMap
1515
@@ -96,31 +96,45 @@ class Inliner(
9696 private def nonNullaryFilter (d : TlaOperDecl ): Boolean =
9797 ! keepNullaryMono || isPolyTag(d.typeTag) || d.formalParams.nonEmpty
9898
99- // Given a declaration (possibly holding a polymorphic type) and a monotyped target, computes
100- // a substitution of the two. A substitution is assumed to exist, otherwise TypingException is thrown.
101- private def getSubstitution (targetType : TlaType1 , decl : TlaOperDecl ): (Substitution , TlaType1 ) = {
102- val genericType = decl.typeTag.asTlaType1()
103- val maxUsedVar = Math .max(genericType.usedNames.foldLeft(0 )(Math .max), targetType.usedNames.foldLeft(0 )(Math .max))
104- new TypeUnifier (new TypeVarPool (maxUsedVar + 1 )).unify(Substitution .empty, genericType, targetType) match {
99+ // Given a declaration `callee` (possibly holding a polymorphic type) and the type at a call site,
100+ // which can also be a polytype, computes a substitution of the two. A substitution is assumed to exist,
101+ // otherwise TypingException is thrown.
102+ private def getSubstitution (callSiteType : TlaType1 , callee : TlaOperDecl ): (Substitution , TlaType1 ) = {
103+ val calleeType = callee.typeTag.asTlaType1()
104+ // To fix transitive inlining of polymorphic operators, rename all type variables in `callee`
105+ // to fresh ones, so these type variables have indices larger than those in `callSiteType`.
106+ // The unification algorithm prefers to keep type variables with smaller indices,
107+ // so this ensures that type variables from `callSiteType` are kept.
108+ // See the issue #3204.
109+ val maxUsedVar = ((calleeType.usedNames ++ callSiteType.usedNames) ++ Set (0 )).max
110+ val typeVarPool = new TypeVarPool (maxUsedVar + 1 )
111+ val calleeSub = Substitution (calleeType.usedNames.map(v => EqClass (v) -> typeVarPool.fresh).toMap)
112+ val calleeTypeRenamed = calleeSub.subRec(calleeType)
113+
114+ new TypeUnifier (typeVarPool).unify(Substitution .empty, callSiteType, calleeTypeRenamed) match {
105115 case None =>
106- throw new TypingException (
107- s " Inliner: Unable to unify generic signature $genericType of ${decl.name} with the concrete type $targetType" ,
108- decl.ID )
109-
110- case Some (pair) => pair
116+ throw new TypingException (s " Inliner: Unable to unify the signature $calleeType of ${callee.name} "
117+ + " with the type $callSiteType at call site" , callee.ID )
118+
119+ case Some ((unifierSub, unifiedType)) =>
120+ // Now, we have to add the renamed type variables back to the substitution.
121+ // To this end, we compose `calleeSub` with `unifierSub`.
122+ // Not that we are not merging them. Otherwise, the type variables of `callee` might take over.
123+ val composed = Substitution (calleeSub.mapping ++ unifierSub.mapping)
124+ (composed, composed.subRec(unifiedType))
111125 }
112126 }
113127
114- // Assume an operator declaration named name is in scope.
128+ // Assume an operator declaration named ` name` is in scope.
115129 // Creates a fresh copy of the operator body and replaces formal parameter instances with the argument instances.
116130 private def instantiateWithArgs (scope : Scope )(nameEx : NameEx , args : Seq [TlaEx ]): TlaEx = {
117131 val name = nameEx.name
118- val decl = scope(name)
132+ val callee = scope(name)
119133
120- val freshBody = deepCopy(decl .body)
134+ val freshBody = deepCopy(callee .body)
121135
122136 // All formal parameters get instantiated at once, to avoid parameter-name issues, see #1903
123- val paramMap = decl .formalParams
137+ val paramMap = callee .formalParams
124138 .zip(args)
125139 .map({ case (OperParam (name, _), arg) =>
126140 name -> arg
@@ -140,28 +154,30 @@ class Inliner(
140154 // To cover both cases at once, we run an additional transform on the replaced body
141155 val newBody = transform(scope)(replacedBody)
142156
143- // Note: it can happen that the new body and the decl have desynced types (poly vs mono).
144- // We fix that below with type unification.
145- // If the operator has a parametric signature, we have to substitute type parameters with concrete parameters
146- // 1. Unify the operator type with the arguments.
147- // 2. Apply the resulting substitution to the types in all subexpressions.
148- val actualType = nameEx.typeTag.asTlaType1()
157+ // Note: it can happen that the type at the call site and the type of `callee` have different types,
158+ // e.g., the `callee` has a more general polytype, or they are both polytypes.
159+ // We fix that below with type unification:
160+ // 1. Unify the operator type with the arguments.
161+ // 2. Apply the resulting substitution to the types in all subexpressions.
162+ // 3. Importantly, we prefer the type variables of the call site, as they are the type variables of
163+ // the caller context.
164+ val callSiteType = nameEx.typeTag.asTlaType1()
149165
150- val (substitution, _) = getSubstitution(actualType, decl )
166+ val (substitution, _) = getSubstitution(callSiteType, callee )
151167
152168 if (substitution.isEmpty) newBody
153169 else new TypeSubstitutor (tracker, substitution)(newBody)
154170 }
155171
156172 // Assume name is in scope. Creates a local LET-IN for pass-by-name operators.
157173 private def embedPassByName (scope : Scope )(nameEx : NameEx ): TlaEx = {
158- val decl = scope(nameEx.name)
174+ val callee = scope(nameEx.name)
159175
160176 // like in instantiateWithArgs, we compare the declaration type to the expected monotype
161- val freshBody = deepCopy(decl .body)
177+ val freshBody = deepCopy(callee .body)
162178 val monoOperType = nameEx.typeTag.asTlaType1()
163179
164- val (substitution, tp) = getSubstitution(monoOperType, decl )
180+ val (substitution, tp) = getSubstitution(monoOperType, callee )
165181
166182 val tpTag = Typed (tp)
167183
@@ -170,8 +186,8 @@ class Inliner(
170186 else new TypeSubstitutor (tracker, substitution)(freshBody)
171187
172188 // To make a local definition, we use a fresh name, derived from the original name, but renamed to get a fresh $N
173- val newName = renaming.apply(NameEx (decl .name)(decl .typeTag)).asInstanceOf [NameEx ].name
174- val newLocalDecl = TlaOperDecl (newName, decl .formalParams, newBody)(tpTag)
189+ val newName = renaming.apply(NameEx (callee .name)(callee .typeTag)).asInstanceOf [NameEx ].name
190+ val newLocalDecl = TlaOperDecl (newName, callee .formalParams, newBody)(tpTag)
175191
176192 LetInEx (NameEx (newName)(tpTag), newLocalDecl)(tpTag)
177193 }
0 commit comments