Skip to content

Commit 8eef09d

Browse files
committed
Introduce predicates for SafepointManager to filter which Threads and Fibers to run the action on
* Avoids adding RubyThread#pendingSafepointActions which end up doing nothing or be executed in the wrong Fiber of that Thread.
1 parent 28d71cd commit 8eef09d

File tree

6 files changed

+82
-51
lines changed

6 files changed

+82
-51
lines changed

src/main/java/org/truffleruby/core/VMPrimitiveNodes.java

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
import org.truffleruby.core.cast.NameToJavaStringNode;
5353
import org.truffleruby.core.cast.ToRubyIntegerNode;
5454
import org.truffleruby.core.exception.RubyException;
55-
import org.truffleruby.core.fiber.FiberManager;
5655
import org.truffleruby.core.klass.RubyClass;
5756
import org.truffleruby.core.method.RubyMethod;
5857
import org.truffleruby.core.module.RubyModule;
@@ -64,8 +63,8 @@
6463
import org.truffleruby.core.string.RubyString;
6564
import org.truffleruby.core.string.StringNodes.MakeStringNode;
6665
import org.truffleruby.core.thread.RubyThread;
67-
import org.truffleruby.core.thread.ThreadManager;
6866
import org.truffleruby.language.RubyDynamicObject;
67+
import org.truffleruby.language.SafepointPredicate;
6968
import org.truffleruby.language.backtrace.Backtrace;
7069
import org.truffleruby.language.backtrace.BacktraceFormatter;
7170
import org.truffleruby.language.control.ExitException;
@@ -270,8 +269,6 @@ protected boolean watchSignalProc(Object signalString, RubyProc action,
270269
}
271270

272271
final RubyThread rootThread = context.getThreadManager().getRootThread();
273-
final FiberManager fiberManager = rootThread.fiberManager;
274-
final ThreadManager threadManager = context.getThreadManager();
275272

276273
// Workaround: we need to register with Truffle (which means going multithreaded),
277274
// so that NFI can get its context to call pthread_kill() (GR-7405).
@@ -296,13 +293,8 @@ protected boolean watchSignalProc(Object signalString, RubyProc action,
296293
try {
297294
context.getSafepointManager().pauseAllThreadsAndExecuteFromNonRubyThread(
298295
"Handling of signal " + signal,
299-
(rubyThread, currentNode) -> {
300-
if (rubyThread == rootThread &&
301-
threadManager.getRubyFiberFromCurrentJavaThread() == fiberManager
302-
.getCurrentFiber()) {
303-
ProcOperations.rootCall(action);
304-
}
305-
});
296+
SafepointPredicate.currentFiberOfThread(context, rootThread),
297+
(rubyThread, currentNode) -> ProcOperations.rootCall(action));
306298
} finally {
307299
truffleContext.leave(this, prev);
308300
}

src/main/java/org/truffleruby/core/thread/RubyThread.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class RubyThread extends RubyDynamicObject implements ObjectGraphNode {
4444
public volatile InterruptMode interruptMode;
4545
public volatile ThreadStatus status;
4646
public final List<Lock> ownedLocks;
47-
public FiberManager fiberManager;
47+
public final FiberManager fiberManager;
4848
CountDownLatch finishedLatch;
4949
final RubyHash threadLocalVariables;
5050
final RubyHash recursiveObjects;

src/main/java/org/truffleruby/core/thread/ThreadManager.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -667,17 +667,15 @@ private void doKillOtherThreads() {
667667
while (true) {
668668
try {
669669
final String reason = "kill other threads for shutdown";
670-
context.getSafepointManager().pauseAllThreadsAndExecute(reason, null, (thread, currentNode) -> {
671-
if (Thread.currentThread() != initiatingJavaThread) {
672-
final FiberManager fiberManager = thread.fiberManager;
673-
final RubyFiber fiber = getRubyFiberFromCurrentJavaThread();
674-
675-
if (fiberManager.getCurrentFiber() == fiber) {
670+
context.getSafepointManager().pauseAllThreadsAndExecute(
671+
reason,
672+
null,
673+
thread -> Thread.currentThread() != initiatingJavaThread &&
674+
getRubyFiberFromCurrentJavaThread() == thread.fiberManager.getCurrentFiber(),
675+
(thread, currentNode) -> {
676676
thread.status = ThreadStatus.ABORTING;
677677
throw new KillException();
678-
}
679-
}
680-
});
678+
});
681679
break; // Successfully executed the safepoint and sent the exceptions.
682680
} catch (RaiseException e) {
683681
final RubyException rubyException = e.getException();

src/main/java/org/truffleruby/language/SafepointManager.java

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ protected boolean onAdvance(int phase, int registeredParties) {
6060
* per context. */
6161
private volatile boolean active = false;
6262

63+
private volatile SafepointPredicate filter;
6364
private volatile SafepointAction action;
6465

6566
public SafepointManager(RubyContext context, RubyLanguage language) {
@@ -155,12 +156,16 @@ private SafepointAction step(Node currentNode, boolean isDrivingThread, String r
155156
// Wait for the assumption to be renewed
156157
phaser.arriveAndAwaitAdvance();
157158

158-
// Read these while in the safepoint
159-
final SafepointAction deferredAction = SafepointAction.isDeferred(action) ? action : null;
159+
// Run the filter and read the action field while in the safepoint
160+
SafepointAction deferredAction = null;
160161

161162
try {
162-
if (action instanceof SafepointAction.Pure) { // same as !deferred
163-
action.accept(thread, currentNode);
163+
if (filter.test(thread)) {
164+
if (SafepointAction.isDeferred(action)) {
165+
deferredAction = action;
166+
} else {
167+
action.accept(thread, currentNode);
168+
}
164169
}
165170
} finally {
166171
// Wait for other threads to finish their action
@@ -256,7 +261,8 @@ private void restoreDefaultInterruptHandler() {
256261
}
257262

258263
@TruffleBoundary
259-
public void pauseAllThreadsAndExecute(String reason, Node currentNode, SafepointAction action) {
264+
public void pauseAllThreadsAndExecute(String reason, Node currentNode, SafepointPredicate filter,
265+
SafepointAction action) {
260266
if (lock.isHeldByCurrentThread()) {
261267
throw new IllegalStateException("Re-entered SafepointManager");
262268
}
@@ -266,20 +272,22 @@ public void pauseAllThreadsAndExecute(String reason, Node currentNode, Safepoint
266272
poll(language, currentNode);
267273
}
268274

275+
final SafepointAction deferredAction;
269276
try {
270-
pauseAllThreadsAndExecuteInternal(reason, currentNode, action);
277+
deferredAction = pauseAllThreadsAndExecuteInternal(reason, currentNode, filter, action);
271278
} finally {
272279
lock.unlock();
273280
}
274281

275282
// Run deferred actions after leaving the SafepointManager lock.
276-
if (SafepointAction.isDeferred(action)) {
283+
if (deferredAction != null) {
277284
action.accept(context.getThreadManager().getCurrentThread(), currentNode);
278285
}
279286
}
280287

281288
@TruffleBoundary
282-
public void pauseAllThreadsAndExecuteFromNonRubyThread(String reason, SafepointAction action) {
289+
public void pauseAllThreadsAndExecuteFromNonRubyThread(String reason, SafepointPredicate filter,
290+
SafepointAction action) {
283291
if (lock.isHeldByCurrentThread()) {
284292
throw new IllegalStateException("Re-entered SafepointManager");
285293
}
@@ -292,7 +300,7 @@ public void pauseAllThreadsAndExecuteFromNonRubyThread(String reason, SafepointA
292300
try {
293301
enterThread();
294302
try {
295-
pauseAllThreadsAndExecuteInternal(reason, null, action);
303+
pauseAllThreadsAndExecuteInternal(reason, null, filter, action);
296304
} finally {
297305
leaveThread();
298306
}
@@ -318,27 +326,20 @@ public void pauseRubyThreadAndExecute(String reason, RubyThread rubyThread, Node
318326
// fast path if we are already the right thread
319327
action.accept(rubyThread, currentNode);
320328
} else {
321-
final SafepointAction filteredAction;
322-
if (action instanceof SafepointAction.Pure) {
323-
filteredAction = (SafepointAction.Pure) (thread, threadCurrentNode) -> {
324-
if (thread == rubyThread &&
325-
threadManager.getRubyFiberFromCurrentJavaThread() == fiberManager.getCurrentFiber()) {
326-
action.accept(thread, threadCurrentNode);
327-
}
328-
};
329-
} else {
330-
filteredAction = (thread, threadCurrentNode) -> {
331-
if (thread == rubyThread &&
332-
threadManager.getRubyFiberFromCurrentJavaThread() == fiberManager.getCurrentFiber()) {
333-
action.accept(thread, threadCurrentNode);
334-
}
335-
};
336-
}
337-
pauseAllThreadsAndExecute(reason, currentNode, filteredAction);
329+
pauseAllThreadsAndExecute(
330+
reason,
331+
currentNode,
332+
SafepointPredicate.currentFiberOfThread(context, rubyThread),
333+
action);
338334
}
339335
}
340336

341-
private void pauseAllThreadsAndExecuteInternal(String reason, Node currentNode, SafepointAction action) {
337+
private SafepointAction pauseAllThreadsAndExecuteInternal(String reason, Node currentNode,
338+
SafepointPredicate filter,
339+
SafepointAction action) {
340+
assert lock.isHeldByCurrentThread();
341+
342+
this.filter = filter;
342343
this.action = action;
343344

344345
/* We need to invalidate first so the interrupted threads see the invalidation in poll() in their
@@ -347,7 +348,7 @@ private void pauseAllThreadsAndExecuteInternal(String reason, Node currentNode,
347348
language.invalidateSafepointAssumption(reason);
348349
interruptOtherThreads();
349350

350-
step(currentNode, true, reason);
351+
return step(currentNode, true, reason);
351352
}
352353

353354
private void interruptOtherThreads() {
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) 2021 Oracle and/or its affiliates. All rights reserved. This
3+
* code is released under a tri EPL/GPL/LGPL license. You can use it,
4+
* redistribute it and/or modify it under the terms of the:
5+
*
6+
* Eclipse Public License version 2.0, or
7+
* GNU General Public License version 2, or
8+
* GNU Lesser General Public License version 2.1.
9+
*/
10+
package org.truffleruby.language;
11+
12+
import org.truffleruby.RubyContext;
13+
import org.truffleruby.core.fiber.FiberManager;
14+
import org.truffleruby.core.thread.RubyThread;
15+
import org.truffleruby.core.thread.ThreadManager;
16+
17+
import java.util.function.Predicate;
18+
19+
public interface SafepointPredicate extends Predicate<RubyThread> {
20+
21+
static final SafepointPredicate ALL_THREADS_AND_FIBERS = rubyThread -> true;
22+
23+
static SafepointPredicate currentFiberOfThread(RubyContext context, RubyThread targetThread) {
24+
final ThreadManager threadManager = context.getThreadManager();
25+
final FiberManager fiberManager = targetThread.fiberManager;
26+
27+
return thread -> thread == targetThread &&
28+
threadManager.getRubyFiberFromCurrentJavaThread() == fiberManager.getCurrentFiber();
29+
}
30+
31+
}

src/main/java/org/truffleruby/language/objects/ObjectGraph.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.truffleruby.language.ImmutableRubyObject;
2222
import org.truffleruby.language.RubyDynamicObject;
2323
import org.truffleruby.language.SafepointAction;
24+
import org.truffleruby.language.SafepointPredicate;
2425
import org.truffleruby.language.arguments.RubyArguments;
2526

2627
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -76,7 +77,11 @@ public static Set<Object> stopAndGetAllObjects(String reason, RubyContext contex
7677
}
7778
};
7879

79-
context.getSafepointManager().pauseAllThreadsAndExecute(reason, currentNode, action);
80+
context.getSafepointManager().pauseAllThreadsAndExecute(
81+
reason,
82+
currentNode,
83+
SafepointPredicate.ALL_THREADS_AND_FIBERS,
84+
action);
8085

8186
return visited;
8287
}
@@ -97,7 +102,11 @@ public static Set<Object> stopAndGetRootObjects(String reason, RubyContext conte
97102
}
98103
};
99104

100-
context.getSafepointManager().pauseAllThreadsAndExecute(reason, currentNode, action);
105+
context.getSafepointManager().pauseAllThreadsAndExecute(
106+
reason,
107+
currentNode,
108+
SafepointPredicate.ALL_THREADS_AND_FIBERS,
109+
action);
101110

102111
return visited;
103112
}

0 commit comments

Comments
 (0)