Skip to content

Commit cdfb356

Browse files
committed
[GR-28711] Fix method invalidation when a call site is used before include/prepend M and a method is added to M later
PullRequest: truffleruby/2659
2 parents 4e214cf + 9a88834 commit cdfb356

File tree

3 files changed

+225
-33
lines changed

3 files changed

+225
-33
lines changed

spec/ruby/core/module/include_spec.rb

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,30 @@ def foo
263263
foo.call.should == 'm'
264264
end
265265

266+
267+
it "updates the method when a module included after a call is later updated" do
268+
m_module = Module.new
269+
a_class = Class.new do
270+
def foo
271+
'a'
272+
end
273+
end
274+
b_class = Class.new(a_class)
275+
b = b_class.new
276+
foo = -> { b.foo }
277+
foo.call.should == 'a'
278+
279+
b_class.include m_module
280+
foo.call.should == 'a'
281+
282+
m_module.module_eval do
283+
def foo
284+
"m"
285+
end
286+
end
287+
foo.call.should == 'm'
288+
end
289+
266290
it "updates the method when a nested included module is updated" do
267291
a_class = Class.new do
268292
def foo

spec/ruby/core/module/prepend_spec.rb

Lines changed: 124 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ def foo
4747
'a'
4848
end
4949
end
50-
b_class = Class.new(a_class)
51-
b = b_class.new
52-
foo = -> { b.foo }
50+
a = a_class.new
51+
foo = -> { a.foo }
5352
foo.call.should == 'a'
5453
a_class.class_eval do
5554
prepend m_module
@@ -65,9 +64,8 @@ def foo
6564
'a'
6665
end
6766
end
68-
b_class = Class.new(a_class)
69-
b = b_class.new
70-
foo = -> { b.foo }
67+
a = a_class.new
68+
foo = -> { a.foo }
7169
foo.call.should == 'a'
7270
m_module.module_eval do
7371
def foo
@@ -77,6 +75,124 @@ def foo
7775
foo.call.should == 'm'
7876
end
7977

78+
it "updates the method when there is a base included method and the prepended module overrides it" do
79+
base_module = Module.new do
80+
def foo
81+
'a'
82+
end
83+
end
84+
a_class = Class.new do
85+
include base_module
86+
end
87+
a = a_class.new
88+
foo = -> { a.foo }
89+
foo.call.should == 'a'
90+
91+
m_module = Module.new do
92+
def foo
93+
"m"
94+
end
95+
end
96+
a_class.prepend m_module
97+
foo.call.should == 'm'
98+
end
99+
100+
it "updates the method when there is a base included method and the prepended module is later updated" do
101+
base_module = Module.new do
102+
def foo
103+
'a'
104+
end
105+
end
106+
a_class = Class.new do
107+
include base_module
108+
end
109+
a = a_class.new
110+
foo = -> { a.foo }
111+
foo.call.should == 'a'
112+
113+
m_module = Module.new
114+
a_class.prepend m_module
115+
foo.call.should == 'a'
116+
117+
m_module.module_eval do
118+
def foo
119+
"m"
120+
end
121+
end
122+
foo.call.should == 'm'
123+
end
124+
125+
it "updates the method when a module prepended after a call is later updated" do
126+
m_module = Module.new
127+
a_class = Class.new do
128+
def foo
129+
'a'
130+
end
131+
end
132+
a = a_class.new
133+
foo = -> { a.foo }
134+
foo.call.should == 'a'
135+
136+
a_class.prepend m_module
137+
foo.call.should == 'a'
138+
139+
m_module.module_eval do
140+
def foo
141+
"m"
142+
end
143+
end
144+
foo.call.should == 'm'
145+
end
146+
147+
it "updates the method when a module is prepended after another and the method is defined later on that module" do
148+
m_module = Module.new do
149+
def foo
150+
'a'
151+
end
152+
end
153+
a_class = Class.new
154+
a_class.prepend m_module
155+
a = a_class.new
156+
foo = -> { a.foo }
157+
foo.call.should == 'a'
158+
159+
n_module = Module.new
160+
a_class.prepend n_module
161+
foo.call.should == 'a'
162+
163+
n_module.module_eval do
164+
def foo
165+
"n"
166+
end
167+
end
168+
foo.call.should == 'n'
169+
end
170+
171+
it "updates the method when a module is included in a prepended module and the method is defined later" do
172+
a_class = Class.new
173+
base_module = Module.new do
174+
def foo
175+
'a'
176+
end
177+
end
178+
a_class.prepend base_module
179+
a = a_class.new
180+
foo = -> { a.foo }
181+
foo.call.should == 'a'
182+
183+
m_module = Module.new
184+
n_module = Module.new
185+
m_module.include n_module
186+
a_class.prepend m_module
187+
188+
n_module.module_eval do
189+
def foo
190+
"n"
191+
end
192+
end
193+
foo.call.should == 'n'
194+
end
195+
80196
it "updates the method when a new module with an included module is prepended" do
81197
a_class = Class.new do
82198
def foo
@@ -94,9 +210,8 @@ def foo
94210
include n_module
95211
end
96212

97-
b_class = Class.new(a_class)
98-
b = b_class.new
99-
foo = -> { b.foo }
213+
a = a_class.new
214+
foo = -> { a.foo }
100215

101216
foo.call.should == 'a'
102217

src/main/java/org/truffleruby/core/module/ModuleFields.java

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.ArrayDeque;
1313
import java.util.ArrayList;
1414
import java.util.Collection;
15+
import java.util.Collections;
1516
import java.util.Deque;
1617
import java.util.HashMap;
1718
import java.util.HashSet;
@@ -21,6 +22,7 @@
2122
import java.util.Map.Entry;
2223
import java.util.NoSuchElementException;
2324
import java.util.Set;
25+
import java.util.WeakHashMap;
2426
import java.util.concurrent.ConcurrentHashMap;
2527
import java.util.concurrent.ConcurrentMap;
2628
import java.util.concurrent.locks.ReentrantLock;
@@ -113,6 +115,17 @@ public static void debugModuleChain(RubyModule module) {
113115
// Concurrency: only modified during boot
114116
private final Map<String, Assumption> inlinedBuiltinsAssumptions = new HashMap<>();
115117

118+
/** A weak set of all classes include-ing or prepend-ing this module M, and of modules which where already prepended
119+
* before to such classes. Used by method invalidation so that lookups before the include/prepend is done can be
120+
* invalidated when a method is added later in M. When M is an included module, defining a method on M only needs to
121+
* invalidate classes in which M is included. However, when M is a prepended module defining a method on M needs to
122+
* invalidate classes and other modules which were prepended before M (since the method could be defined there, and
123+
* then the assumption on M would not be checked). Not that there is no need to be transitive here, include/prepend
124+
* snapshot the modules to include alongside M, and later M.include/M.prepend have no effect on classes already
125+
* including M. In other words,
126+
* {@code module A; end; module B; end; class C; end; C.include A; A.include B; C.ancestors.include?(B) => false} */
127+
private final Set<RubyModule> includedBy;
128+
116129
public ModuleFields(
117130
RubyLanguage language,
118131
SourceSection sourceSection,
@@ -128,6 +141,9 @@ public ModuleFields(
128141
this.constantsUnmodifiedAssumption = new CyclicAssumption("constants are unmodified");
129142
classVariables = new ClassVariableStorage(language);
130143
start = new PrependMarker(this);
144+
this.includedBy = rubyModule instanceof RubyClass
145+
? null
146+
: Collections.newSetFromMap(Collections.synchronizedMap(new WeakHashMap<>()));
131147
}
132148

133149
public RubyConstant getAdoptedByLexicalParent(
@@ -295,11 +311,14 @@ public void include(RubyContext context, Node currentNode, RubyModule module) {
295311

296312
private void performIncludes(ModuleChain inclusionPoint, Deque<RubyModule> moduleAncestors) {
297313
while (!moduleAncestors.isEmpty()) {
298-
RubyModule mod = moduleAncestors.pop();
299-
inclusionPoint.insertAfter(mod);
300-
// Module#include only adds modules between the current class and the super class,
301-
// so invalidating the current class is enough as all affected lookups would go through the current class.
302-
newMethodsVersion(mod.fields.getMethodNames());
314+
RubyModule toInclude = moduleAncestors.pop();
315+
inclusionPoint.insertAfter(toInclude);
316+
if (rubyModule instanceof RubyClass) { // M.include(N) just registers N but does nothing until C.include/prepend(M)
317+
toInclude.fields.includedBy.add(rubyModule);
318+
// Module#include only adds modules between the current class and the super class,
319+
// so invalidating the current class is enough as all affected lookups would go through the current class.
320+
newMethodsVersion(toInclude.fields.getMethodNames());
321+
}
303322
}
304323
}
305324

@@ -327,31 +346,50 @@ public void prepend(RubyContext context, Node currentNode, RubyModule module) {
327346

328347
SharedObjects.propagate(context.getLanguageSlow(), rubyModule, module);
329348

330-
// Previous calls on instances of the current class must have looked up through the first prepended module,
331-
// so invalidate that one.
332-
final ModuleFields moduleFieldsToInvalidate = getFirstModuleChain().getActualModule().fields;
349+
/* We need to invalidate all prepended modules and the class, because call sites which looked up methods before
350+
* only check the class or one of the prepend module (if the method is defined there). */
351+
final List<RubyModule> prependedModulesAndClass = rubyModule instanceof RubyClass
352+
? getPrependedModulesAndClass()
353+
: null;
333354

334355
ModuleChain mod = module.fields.start;
335356
ModuleChain cur = start;
336357
while (mod != null &&
337358
!(mod instanceof ModuleFields && ((ModuleFields) mod).rubyModule instanceof RubyClass)) {
338359
if (!(mod instanceof PrependMarker)) {
339-
final RubyModule actualModule = mod.getActualModule();
340-
if (!ModuleOperations.includesModule(rubyModule, actualModule)) {
341-
cur.insertAfter(actualModule);
342-
moduleFieldsToInvalidate.newMethodsVersion(actualModule.fields.getMethodNames());
360+
final RubyModule toPrepend = mod.getActualModule();
361+
if (!ModuleOperations.includesModule(rubyModule, toPrepend)) {
362+
cur.insertAfter(toPrepend);
363+
if (rubyModule instanceof RubyClass) { // M.prepend(N) just registers N but does nothing until C.prepend/include(M)
364+
final List<String> methodsToInvalidate = toPrepend.fields.getMethodNames();
365+
for (RubyModule moduleToInvalidate : prependedModulesAndClass) {
366+
toPrepend.fields.includedBy.add(moduleToInvalidate);
367+
moduleToInvalidate.fields.newMethodsVersion(methodsToInvalidate);
368+
}
369+
}
343370
cur = cur.getParentModule();
344371
}
345372
}
346373
mod = mod.getParentModule();
347374
}
348375

349376
// If there were already prepended modules, invalidate the first of them
350-
moduleFieldsToInvalidate.newHierarchyVersion();
377+
newHierarchyVersion();
351378

352379
invalidateBuiltinsAssumptions();
353380
}
354381

382+
private List<RubyModule> getPrependedModulesAndClass() {
383+
final List<RubyModule> prependedModulesAndClass = new ArrayList<>();
384+
ModuleChain chain = getFirstModuleChain();
385+
while (chain != this) {
386+
prependedModulesAndClass.add(chain.getActualModule());
387+
chain = chain.getParentModule();
388+
}
389+
prependedModulesAndClass.add(rubyModule);
390+
return prependedModulesAndClass;
391+
}
392+
355393
/** Set the value of a constant, possibly redefining it. */
356394
@TruffleBoundary
357395
public RubyConstant setConstant(RubyContext context, Node currentNode, String name, Object value) {
@@ -463,6 +501,11 @@ public void addMethod(RubyContext context, Node currentNode, InternalMethod meth
463501
if (previousMethodEntry != null) {
464502
previousMethodEntry.invalidate(rubyModule, method.getName());
465503
}
504+
505+
if (includedBy != null && !includedBy.isEmpty()) {
506+
invalidateIncludedBy(method.getName());
507+
}
508+
466509
// invalidate assumptions to not use an AST-inlined methods
467510
changedMethod(method.getName());
468511
if (refinedModule != null) {
@@ -733,17 +776,27 @@ public void newHierarchyVersion() {
733776
}
734777
}
735778

736-
public void newMethodsVersion(List<String> methodsToInvalidate) {
737-
for (String entryToInvalidate : methodsToInvalidate) {
738-
while (true) {
739-
final MethodEntry methodEntry = methods.get(entryToInvalidate);
740-
if (methodEntry == null) {
741-
break;
742-
} else {
743-
methodEntry.invalidate(rubyModule, entryToInvalidate);
744-
if (methods.replace(entryToInvalidate, methodEntry, methodEntry.withNewAssumption())) {
745-
break;
746-
}
779+
private void invalidateIncludedBy(String method) {
780+
for (RubyModule module : includedBy) {
781+
module.fields.newMethodVersion(method);
782+
}
783+
}
784+
785+
public void newMethodsVersion(Collection<String> methodsToInvalidate) {
786+
for (String name : methodsToInvalidate) {
787+
newMethodVersion(name);
788+
}
789+
}
790+
791+
private void newMethodVersion(String methodToInvalidate) {
792+
while (true) {
793+
final MethodEntry methodEntry = methods.get(methodToInvalidate);
794+
if (methodEntry == null) {
795+
return;
796+
} else {
797+
if (methods.replace(methodToInvalidate, methodEntry, methodEntry.withNewAssumption())) {
798+
methodEntry.invalidate(rubyModule, methodToInvalidate);
799+
return;
747800
}
748801
}
749802
}

0 commit comments

Comments
 (0)