Skip to content

Commit d89a4af

Browse files
committed
[SPARK-52675][ML][CONNECT] Interrupt hanging ML handlers in tests
### What changes were proposed in this pull request? Interrupt hanging ML handlers in tests: Recently some ML connect tests hang randomly. We need to add code to interrupt hanging handler threads and print the stack trace, for debuggability. ### Why are the changes needed? Recently some ML connect tests hang randomly. We need to add code to interrupt hanging handler threads and print the stack trace, for debuggability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51364 from WeichenXu123/ml-connect-hang-debugger. Authored-by: Weichen Xu <weichen.xu@databricks.com> Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
1 parent 83514c9 commit d89a4af

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

python/pyspark/testing/connectutils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ def master(cls):
166166

167167
@classmethod
168168
def setUpClass(cls):
169+
# This environment variable is for interrupting hanging ML-handler and making the
170+
# tests fail fast.
171+
os.environ["SPARK_CONNECT_ML_HANDLER_INTERRUPTION_TIMEOUT_MINUTES"] = "5"
169172
cls.spark = (
170173
PySparkSession.builder.config(conf=cls.conf())
171174
.appName(cls.__name__)

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717

1818
package org.apache.spark.sql.connect.ml
1919

20+
import java.io.{PrintWriter, StringWriter}
2021
import java.lang.ThreadLocal
22+
import java.util.concurrent.ConcurrentHashMap
2123

2224
import scala.collection.mutable
2325
import scala.jdk.CollectionConverters.CollectionHasAsScala
2426

27+
import org.apache.spark.SparkException
2528
import org.apache.spark.connect.proto
2629
import org.apache.spark.internal.Logging
2730
import org.apache.spark.ml.{Estimator, EstimatorUtils, Model, Transformer}
@@ -121,6 +124,9 @@ private[connect] object MLHandler extends Logging {
121124
override def initialValue: SessionHolder = null
122125
}
123126

127+
// A map of thread-id -> handler execution start time (milliseconds since the UNIX epoch)
128+
val handlerExecutionStartTimeMap = new ConcurrentHashMap[Long, Long]()
129+
124130
private val allowlistedMLClasses = {
125131
val transformerClasses = MLUtils.loadOperators(classOf[Transformer])
126132
val estimatorClasses = MLUtils.loadOperators(classOf[Estimator[_]])
@@ -150,7 +156,43 @@ private[connect] object MLHandler extends Logging {
150156
}
151157
}
152158

153-
def handleMlCommand(
159+
/**
 * Starts a daemon thread that interrupts ML handler threads which have been running for too long.
 *
 * The timeout is read from the environment variable
 * `SPARK_CONNECT_ML_HANDLER_INTERRUPTION_TIMEOUT_MINUTES`. If the variable is unset, cannot be
 * parsed as an integer, or is not positive, no reaper thread is started.
 *
 * Once a minute the reaper scans `handlerExecutionStartTimeMap` and calls `interrupt()` on every
 * live thread whose recorded start time is older than the timeout. The reaper never removes
 * entries itself, so a thread that stays registered is interrupted again on each later scan.
 */
def startHangingHandlerReaper(): Unit = {
  val handlerInterruptionTimeoutMinutes: Int =
    try {
      Option(System.getenv("SPARK_CONNECT_ML_HANDLER_INTERRUPTION_TIMEOUT_MINUTES"))
        .map(_.toInt)
        .getOrElse(0)
    } catch {
      // Best effort: any failure to read or parse the setting disables the reaper.
      case _: Exception => 0
    }

  if (handlerInterruptionTimeoutMinutes > 0) {
    // Use Long arithmetic: `minutes * 60 * 1000` overflows Int for large settings, which would
    // yield a negative timeout and interrupt every registered handler.
    val handlerInterruptionTimeoutMillis = handlerInterruptionTimeoutMinutes.toLong * 60 * 1000
    val thread = new Thread(() => {
      while (true) {
        handlerExecutionStartTimeMap.forEach { (threadId, startTime) =>
          val execTime = System.currentTimeMillis() - startTime
          if (execTime > handlerInterruptionTimeoutMillis) {
            // Look the thread up by id among the live threads and interrupt it.
            Thread.getAllStackTraces().keySet().asScala
              .filter(_.getId() == threadId)
              .foreach(_.interrupt())
          }
        }
        Thread.sleep(60 * 1000)
      }
    })
    // A name makes the reaper easy to spot in thread dumps.
    thread.setName("spark-connect-ml-hanging-handler-reaper")
    thread.setDaemon(true)
    thread.start()
  }
}
193+
startHangingHandlerReaper()
194+
195+
def _handleMlCommand(
154196
sessionHolder: SessionHolder,
155197
mlCommand: proto.MlCommand): proto.MlCommandResult = {
156198

@@ -410,6 +452,39 @@ private[connect] object MLHandler extends Logging {
410452
}
411453
}
412454

455+
/**
 * Runs `originHandler` on the current thread while the thread is registered in
 * `handlerExecutionStartTimeMap`, so that its execution time can be monitored.
 *
 * Only the outermost call on a thread registers and unregisters the thread. A nested call
 * leaves the existing entry untouched; otherwise the inner call would overwrite the outer start
 * time and its cleanup would remove the entry while the outer handler is still running.
 *
 * @param originHandler the handler body to execute
 * @param reqProto the request message, included in the error raised on interruption
 * @return whatever `originHandler` returns
 * @throws SparkException an internal error if the handler is stopped by an
 *                        `InterruptedException`; the message carries the elapsed minutes, the
 *                        request proto and the stack trace of the interruption
 */
def wrapHandler(
    originHandler: () => Any,
    reqProto: com.google.protobuf.GeneratedMessage): Any = {
  val threadId = Thread.currentThread().getId
  val startTime = System.currentTimeMillis()
  // Within this file the entry for a thread id is only written by that thread itself (here),
  // so checking and then inserting does not race with other writers of the same key.
  val isOutermostCall = !handlerExecutionStartTimeMap.containsKey(threadId)
  if (isOutermostCall) {
    handlerExecutionStartTimeMap.put(threadId, startTime)
  }
  try {
    originHandler()
  } catch {
    case e: InterruptedException =>
      // Render the stack trace of the interruption point into a string for the error message.
      val traceBuffer = new StringWriter()
      e.printStackTrace(new PrintWriter(traceBuffer))
      val stackTrace = traceBuffer.toString
      val execTime = (System.currentTimeMillis() - startTime) / (60 * 1000)
      throw SparkException.internalError(
        s"The Spark Connect ML handler thread is interrupted after executing for " +
          s"$execTime minutes.\nThe request proto message is:\n${reqProto.toString}\n, " +
          s"the current stack trace is:\n$stackTrace\n")
  } finally {
    if (isOutermostCall) {
      handlerExecutionStartTimeMap.remove(threadId)
    }
  }
}
480+
481+
/**
 * Entry point for ML commands: delegates to `_handleMlCommand` through `wrapHandler`, passing
 * the command proto along as the request message.
 */
def handleMlCommand(
    sessionHolder: SessionHolder,
    mlCommand: proto.MlCommand): proto.MlCommandResult = {
  val runCommand: () => Any = () => _handleMlCommand(sessionHolder, mlCommand)
  val outcome = wrapHandler(runCommand, mlCommand)
  // `wrapHandler` is typed as `Any`; the wrapped handler produces a `MlCommandResult`.
  outcome.asInstanceOf[proto.MlCommandResult]
}
487+
413488
private def createModelSummary(
414489
sessionHolder: SessionHolder,
415490
createSummaryCmd: proto.MlCommand.CreateSummary): proto.MlCommandResult =
@@ -431,7 +506,9 @@ private[connect] object MLHandler extends Logging {
431506
.build()
432507
}
433508

434-
def transformMLRelation(relation: proto.MlRelation, sessionHolder: SessionHolder): DataFrame = {
509+
def _transformMLRelation(
510+
relation: proto.MlRelation,
511+
sessionHolder: SessionHolder): DataFrame = {
435512
relation.getMlTypeCase match {
436513
// Ml transform
437514
case proto.MlRelation.MlTypeCase.TRANSFORM =>
@@ -487,4 +564,9 @@ private[connect] object MLHandler extends Logging {
487564
case other => throw MlUnsupportedException(s"$other not supported")
488565
}
489566
}
567+
568+
/**
 * Entry point for ML relations: delegates to `_transformMLRelation` through `wrapHandler`,
 * passing the relation proto along as the request message.
 */
def transformMLRelation(relation: proto.MlRelation, sessionHolder: SessionHolder): DataFrame = {
  val runTransform: () => Any = () => _transformMLRelation(relation, sessionHolder)
  val outcome = wrapHandler(runTransform, relation)
  // `wrapHandler` is typed as `Any`; the wrapped handler produces a `DataFrame`.
  outcome.asInstanceOf[DataFrame]
}
490572
}

0 commit comments

Comments
 (0)