Skip to content

Commit 488158c

Browse files
committed
Fix method invalidation when a call site is used before include/prepend M and a method is added to M later
1 parent e5a75b3 commit 488158c

File tree

1 file changed

+69
-22
lines changed

1 file changed

+69
-22
lines changed

src/main/java/org/truffleruby/core/module/ModuleFields.java

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.ArrayDeque;
1313
import java.util.ArrayList;
1414
import java.util.Collection;
15+
import java.util.Collections;
1516
import java.util.Deque;
1617
import java.util.HashMap;
1718
import java.util.HashSet;
@@ -21,6 +22,7 @@
2122
import java.util.Map.Entry;
2223
import java.util.NoSuchElementException;
2324
import java.util.Set;
25+
import java.util.WeakHashMap;
2426
import java.util.concurrent.ConcurrentHashMap;
2527
import java.util.concurrent.ConcurrentMap;
2628
import java.util.concurrent.locks.ReentrantLock;
@@ -113,6 +115,17 @@ public static void debugModuleChain(RubyModule module) {
113115
// Concurrency: only modified during boot
114116
private final Map<String, Assumption> inlinedBuiltinsAssumptions = new HashMap<>();
115117

118+
/** A weak set of all classes include-ing or prepend-ing this module M, and of modules which where already prepended
119+
* before to such classes. Used by method invalidation so that lookups before the include/prepend is done can be
120+
* invalidated when a method is added later in M. When M is an included module, defining a method on M only needs to
121+
* invalidate classes in which M is included. However, when M is a prepended module defining a method on M needs to
122+
* invalidate classes and other modules which were prepended before M (since the method could be defined there, and
123+
* then the assumption on M would not be checked). Not that there is no need to be transitive here, include/prepend
124+
* snapshot the modules to include alongside M, and later M.include/M.prepend have no effect on classes already
125+
* including M. In other words,
126+
* {@code module A; end; module B; end; class C; end; C.include A; A.include B; C.ancestors.include?(B) => false} */
127+
private final Set<RubyModule> includedBy;
128+
116129
public ModuleFields(
117130
RubyLanguage language,
118131
SourceSection sourceSection,
@@ -128,6 +141,9 @@ public ModuleFields(
128141
this.constantsUnmodifiedAssumption = new CyclicAssumption("constants are unmodified");
129142
classVariables = new ClassVariableStorage(language);
130143
start = new PrependMarker(this);
144+
this.includedBy = rubyModule instanceof RubyClass
145+
? null
146+
: Collections.newSetFromMap(Collections.synchronizedMap(new WeakHashMap<>()));
131147
}
132148

133149
public RubyConstant getAdoptedByLexicalParent(
@@ -295,11 +311,12 @@ public void include(RubyContext context, Node currentNode, RubyModule module) {
295311

296312
private void performIncludes(ModuleChain inclusionPoint, Deque<RubyModule> moduleAncestors) {
297313
while (!moduleAncestors.isEmpty()) {
298-
RubyModule mod = moduleAncestors.pop();
299-
inclusionPoint.insertAfter(mod);
314+
RubyModule toInclude = moduleAncestors.pop();
315+
inclusionPoint.insertAfter(toInclude);
316+
toInclude.fields.includedBy.add(rubyModule);
300317
// Module#include only adds modules between the current class and the super class,
301318
// so invalidating the current class is enough as all affected lookups would go through the current class.
302-
newMethodsVersion(mod.fields.getMethodNames());
319+
newMethodsVersion(toInclude.fields.getMethodNames());
303320
}
304321
}
305322

@@ -327,31 +344,46 @@ public void prepend(RubyContext context, Node currentNode, RubyModule module) {
327344

328345
SharedObjects.propagate(context.getLanguageSlow(), rubyModule, module);
329346

330-
// Previous calls on instances of the current class must have looked up through the first prepended module,
331-
// so invalidate that one.
332-
final ModuleFields moduleFieldsToInvalidate = getFirstModuleChain().getActualModule().fields;
347+
/* We need to invalidate all prepended modules and the class, because call sites which looked up methods before
348+
* only check the class or one of the prepend module (if the method is defined there). */
349+
final List<RubyModule> prependedModulesAndSelf = getPrependedModulesAndSelf();
333350

334351
ModuleChain mod = module.fields.start;
335352
ModuleChain cur = start;
336353
while (mod != null &&
337354
!(mod instanceof ModuleFields && ((ModuleFields) mod).rubyModule instanceof RubyClass)) {
338355
if (!(mod instanceof PrependMarker)) {
339-
final RubyModule actualModule = mod.getActualModule();
340-
if (!ModuleOperations.includesModule(rubyModule, actualModule)) {
341-
cur.insertAfter(actualModule);
342-
moduleFieldsToInvalidate.newMethodsVersion(actualModule.fields.getMethodNames());
356+
final RubyModule toPrepend = mod.getActualModule();
357+
if (!ModuleOperations.includesModule(rubyModule, toPrepend)) {
358+
cur.insertAfter(toPrepend);
359+
final List<String> methodsToInvalidate = toPrepend.fields.getMethodNames();
360+
for (RubyModule moduleToInvalidate : prependedModulesAndSelf) {
361+
toPrepend.fields.includedBy.add(moduleToInvalidate);
362+
moduleToInvalidate.fields.newMethodsVersion(methodsToInvalidate);
363+
}
343364
cur = cur.getParentModule();
344365
}
345366
}
346367
mod = mod.getParentModule();
347368
}
348369

349370
// If there were already prepended modules, invalidate the first of them
350-
moduleFieldsToInvalidate.newHierarchyVersion();
371+
newHierarchyVersion();
351372

352373
invalidateBuiltinsAssumptions();
353374
}
354375

376+
private List<RubyModule> getPrependedModulesAndSelf() {
377+
final List<RubyModule> prependedModulesAndClass = new ArrayList<>();
378+
ModuleChain chain = getFirstModuleChain();
379+
while (chain != this) {
380+
prependedModulesAndClass.add(chain.getActualModule());
381+
chain = chain.getParentModule();
382+
}
383+
prependedModulesAndClass.add(rubyModule);
384+
return prependedModulesAndClass;
385+
}
386+
355387
/** Set the value of a constant, possibly redefining it. */
356388
@TruffleBoundary
357389
public RubyConstant setConstant(RubyContext context, Node currentNode, String name, Object value) {
@@ -463,6 +495,11 @@ public void addMethod(RubyContext context, Node currentNode, InternalMethod meth
463495
if (previousMethodEntry != null) {
464496
previousMethodEntry.invalidate(rubyModule, method.getName());
465497
}
498+
499+
if (includedBy != null && !includedBy.isEmpty()) {
500+
invalidateIncludedBy(method.getName());
501+
}
502+
466503
// invalidate assumptions to not use an AST-inlined methods
467504
changedMethod(method.getName());
468505
if (refinedModule != null) {
@@ -733,17 +770,27 @@ public void newHierarchyVersion() {
733770
}
734771
}
735772

736-
public void newMethodsVersion(List<String> methodsToInvalidate) {
737-
for (String entryToInvalidate : methodsToInvalidate) {
738-
while (true) {
739-
final MethodEntry methodEntry = methods.get(entryToInvalidate);
740-
if (methodEntry == null) {
741-
break;
742-
} else {
743-
methodEntry.invalidate(rubyModule, entryToInvalidate);
744-
if (methods.replace(entryToInvalidate, methodEntry, methodEntry.withNewAssumption())) {
745-
break;
746-
}
773+
private void invalidateIncludedBy(String method) {
774+
for (RubyModule module : includedBy) {
775+
module.fields.newMethodVersion(method);
776+
}
777+
}
778+
779+
public void newMethodsVersion(Collection<String> methodsToInvalidate) {
780+
for (String name : methodsToInvalidate) {
781+
newMethodVersion(name);
782+
}
783+
}
784+
785+
private void newMethodVersion(String methodToInvalidate) {
786+
while (true) {
787+
final MethodEntry methodEntry = methods.get(methodToInvalidate);
788+
if (methodEntry == null) {
789+
return;
790+
} else {
791+
methodEntry.invalidate(rubyModule, methodToInvalidate);
792+
if (methods.replace(methodToInvalidate, methodEntry, methodEntry.withNewAssumption())) {
793+
return;
747794
}
748795
}
749796
}

0 commit comments

Comments
 (0)