12
12
import java .util .ArrayDeque ;
13
13
import java .util .ArrayList ;
14
14
import java .util .Collection ;
15
+ import java .util .Collections ;
15
16
import java .util .Deque ;
16
17
import java .util .HashMap ;
17
18
import java .util .HashSet ;
21
22
import java .util .Map .Entry ;
22
23
import java .util .NoSuchElementException ;
23
24
import java .util .Set ;
25
+ import java .util .WeakHashMap ;
24
26
import java .util .concurrent .ConcurrentHashMap ;
25
27
import java .util .concurrent .ConcurrentMap ;
26
28
import java .util .concurrent .locks .ReentrantLock ;
@@ -113,6 +115,17 @@ public static void debugModuleChain(RubyModule module) {
113
115
// Concurrency: only modified during boot
114
116
private final Map <String , Assumption > inlinedBuiltinsAssumptions = new HashMap <>();
115
117
118
+ /** A weak set of all classes include-ing or prepend-ing this module M, and of modules which where already prepended
119
+ * before to such classes. Used by method invalidation so that lookups before the include/prepend is done can be
120
+ * invalidated when a method is added later in M. When M is an included module, defining a method on M only needs to
121
+ * invalidate classes in which M is included. However, when M is a prepended module defining a method on M needs to
122
+ * invalidate classes and other modules which were prepended before M (since the method could be defined there, and
123
+ * then the assumption on M would not be checked). Not that there is no need to be transitive here, include/prepend
124
+ * snapshot the modules to include alongside M, and later M.include/M.prepend have no effect on classes already
125
+ * including M. In other words,
126
+ * {@code module A; end; module B; end; class C; end; C.include A; A.include B; C.ancestors.include?(B) => false} */
127
+ private final Set <RubyModule > includedBy ;
128
+
116
129
public ModuleFields (
117
130
RubyLanguage language ,
118
131
SourceSection sourceSection ,
@@ -128,6 +141,9 @@ public ModuleFields(
128
141
this .constantsUnmodifiedAssumption = new CyclicAssumption ("constants are unmodified" );
129
142
classVariables = new ClassVariableStorage (language );
130
143
start = new PrependMarker (this );
144
+ this .includedBy = rubyModule instanceof RubyClass
145
+ ? null
146
+ : Collections .newSetFromMap (Collections .synchronizedMap (new WeakHashMap <>()));
131
147
}
132
148
133
149
public RubyConstant getAdoptedByLexicalParent (
@@ -295,11 +311,14 @@ public void include(RubyContext context, Node currentNode, RubyModule module) {
295
311
296
312
private void performIncludes (ModuleChain inclusionPoint , Deque <RubyModule > moduleAncestors ) {
297
313
while (!moduleAncestors .isEmpty ()) {
298
- RubyModule mod = moduleAncestors .pop ();
299
- inclusionPoint .insertAfter (mod );
300
- // Module#include only adds modules between the current class and the super class,
301
- // so invalidating the current class is enough as all affected lookups would go through the current class.
302
- newMethodsVersion (mod .fields .getMethodNames ());
314
+ RubyModule toInclude = moduleAncestors .pop ();
315
+ inclusionPoint .insertAfter (toInclude );
316
+ if (rubyModule instanceof RubyClass ) { // M.include(N) just registers N but does nothing until C.include/prepend(M)
317
+ toInclude .fields .includedBy .add (rubyModule );
318
+ // Module#include only adds modules between the current class and the super class,
319
+ // so invalidating the current class is enough as all affected lookups would go through the current class.
320
+ newMethodsVersion (toInclude .fields .getMethodNames ());
321
+ }
303
322
}
304
323
}
305
324
@@ -327,31 +346,50 @@ public void prepend(RubyContext context, Node currentNode, RubyModule module) {
327
346
328
347
SharedObjects .propagate (context .getLanguageSlow (), rubyModule , module );
329
348
330
- // Previous calls on instances of the current class must have looked up through the first prepended module,
331
- // so invalidate that one.
332
- final ModuleFields moduleFieldsToInvalidate = getFirstModuleChain ().getActualModule ().fields ;
349
+ /* We need to invalidate all prepended modules and the class, because call sites which looked up methods before
350
+ * only check the class or one of the prepend module (if the method is defined there). */
351
+ final List <RubyModule > prependedModulesAndClass = rubyModule instanceof RubyClass
352
+ ? getPrependedModulesAndClass ()
353
+ : null ;
333
354
334
355
ModuleChain mod = module .fields .start ;
335
356
ModuleChain cur = start ;
336
357
while (mod != null &&
337
358
!(mod instanceof ModuleFields && ((ModuleFields ) mod ).rubyModule instanceof RubyClass )) {
338
359
if (!(mod instanceof PrependMarker )) {
339
- final RubyModule actualModule = mod .getActualModule ();
340
- if (!ModuleOperations .includesModule (rubyModule , actualModule )) {
341
- cur .insertAfter (actualModule );
342
- moduleFieldsToInvalidate .newMethodsVersion (actualModule .fields .getMethodNames ());
360
+ final RubyModule toPrepend = mod .getActualModule ();
361
+ if (!ModuleOperations .includesModule (rubyModule , toPrepend )) {
362
+ cur .insertAfter (toPrepend );
363
+ if (rubyModule instanceof RubyClass ) { // M.prepend(N) just registers N but does nothing until C.prepend/include(M)
364
+ final List <String > methodsToInvalidate = toPrepend .fields .getMethodNames ();
365
+ for (RubyModule moduleToInvalidate : prependedModulesAndClass ) {
366
+ toPrepend .fields .includedBy .add (moduleToInvalidate );
367
+ moduleToInvalidate .fields .newMethodsVersion (methodsToInvalidate );
368
+ }
369
+ }
343
370
cur = cur .getParentModule ();
344
371
}
345
372
}
346
373
mod = mod .getParentModule ();
347
374
}
348
375
349
376
// If there were already prepended modules, invalidate the first of them
350
- moduleFieldsToInvalidate . newHierarchyVersion ();
377
+ newHierarchyVersion ();
351
378
352
379
invalidateBuiltinsAssumptions ();
353
380
}
354
381
382
+ private List <RubyModule > getPrependedModulesAndClass () {
383
+ final List <RubyModule > prependedModulesAndClass = new ArrayList <>();
384
+ ModuleChain chain = getFirstModuleChain ();
385
+ while (chain != this ) {
386
+ prependedModulesAndClass .add (chain .getActualModule ());
387
+ chain = chain .getParentModule ();
388
+ }
389
+ prependedModulesAndClass .add (rubyModule );
390
+ return prependedModulesAndClass ;
391
+ }
392
+
355
393
/** Set the value of a constant, possibly redefining it. */
356
394
@ TruffleBoundary
357
395
public RubyConstant setConstant (RubyContext context , Node currentNode , String name , Object value ) {
@@ -463,6 +501,11 @@ public void addMethod(RubyContext context, Node currentNode, InternalMethod meth
463
501
if (previousMethodEntry != null ) {
464
502
previousMethodEntry .invalidate (rubyModule , method .getName ());
465
503
}
504
+
505
+ if (includedBy != null && !includedBy .isEmpty ()) {
506
+ invalidateIncludedBy (method .getName ());
507
+ }
508
+
466
509
// invalidate assumptions to not use an AST-inlined methods
467
510
changedMethod (method .getName ());
468
511
if (refinedModule != null ) {
@@ -733,17 +776,27 @@ public void newHierarchyVersion() {
733
776
}
734
777
}
735
778
736
- public void newMethodsVersion (List <String > methodsToInvalidate ) {
737
- for (String entryToInvalidate : methodsToInvalidate ) {
738
- while (true ) {
739
- final MethodEntry methodEntry = methods .get (entryToInvalidate );
740
- if (methodEntry == null ) {
741
- break ;
742
- } else {
743
- methodEntry .invalidate (rubyModule , entryToInvalidate );
744
- if (methods .replace (entryToInvalidate , methodEntry , methodEntry .withNewAssumption ())) {
745
- break ;
746
- }
779
+ private void invalidateIncludedBy (String method ) {
780
+ for (RubyModule module : includedBy ) {
781
+ module .fields .newMethodVersion (method );
782
+ }
783
+ }
784
+
785
+ public void newMethodsVersion (Collection <String > methodsToInvalidate ) {
786
+ for (String name : methodsToInvalidate ) {
787
+ newMethodVersion (name );
788
+ }
789
+ }
790
+
791
+ private void newMethodVersion (String methodToInvalidate ) {
792
+ while (true ) {
793
+ final MethodEntry methodEntry = methods .get (methodToInvalidate );
794
+ if (methodEntry == null ) {
795
+ return ;
796
+ } else {
797
+ if (methods .replace (methodToInvalidate , methodEntry , methodEntry .withNewAssumption ())) {
798
+ methodEntry .invalidate (rubyModule , methodToInvalidate );
799
+ return ;
747
800
}
748
801
}
749
802
}
0 commit comments