@@ -17,11 +17,14 @@

 package org.apache.spark.sql.connect.ml

+import java.io.{PrintWriter, StringWriter}
 import java.lang.ThreadLocal
+import java.util.concurrent.ConcurrentHashMap

 import scala.collection.mutable
 import scala.jdk.CollectionConverters.CollectionHasAsScala

+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, EstimatorUtils, Model, Transformer}
@@ -121,6 +124,9 @@ private[connect] object MLHandler extends Logging {
     override def initialValue: SessionHolder = null
   }

+  // A map of thread-id -> handler execution start time (epoch milliseconds)
+  val handlerExecutionStartTimeMap = new ConcurrentHashMap[Long, Long]()
+
   private val allowlistedMLClasses = {
     val transformerClasses = MLUtils.loadOperators(classOf[Transformer])
     val estimatorClasses = MLUtils.loadOperators(classOf[Estimator[_]])
@@ -150,7 +156,43 @@ private[connect] object MLHandler extends Logging {
     }
   }

-  def handleMlCommand(
+  def startHangingHandlerReaper(): Unit = {
+    val handlerInterruptionTimeoutMinutes = {
+      try {
+        val envValue = System.getenv("SPARK_CONNECT_ML_HANDLER_INTERRUPTION_TIMEOUT_MINUTES")
+        if (envValue != null) {
+          envValue.toInt
+        } else { 0 }
+      } catch {
+        case _: Exception => 0
+      }
+    }
+
+    if (handlerInterruptionTimeoutMinutes > 0) {
+      val handlerInterruptionTimeoutMillis = handlerInterruptionTimeoutMinutes * 60 * 1000
+      val thread = new Thread(() => {
+        while (true) {
+          handlerExecutionStartTimeMap.forEach { (threadId, startTime) =>
+            val execTime = System.currentTimeMillis() - startTime
+            if (execTime > handlerInterruptionTimeoutMillis) {
+              for (t <- Thread.getAllStackTraces().keySet().asScala) {
+                if (t.getId() == threadId) {
+                  t.interrupt()
+                }
+              }
+            }
+          }
+          Thread.sleep(60 * 1000)
+        }
+      })
+      thread.setDaemon(true)
+      thread.start()
+    }
+  }
+
+  startHangingHandlerReaper()
+
+  def _handleMlCommand(
       sessionHolder: SessionHolder,
       mlCommand: proto.MlCommand): proto.MlCommandResult = {
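For context, the hunk above implements a watchdog: each handler thread records its start time in `handlerExecutionStartTimeMap`, and a daemon thread wakes once a minute and interrupts any thread that has run past the configured timeout. Below is a minimal, self-contained sketch of the same pattern, with hypothetical names (`ReaperSketch`, `startTimes`) and a seconds-scale timeout so it can be run directly; it is an illustration, not code from the patch.

import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters.CollectionHasAsScala

object ReaperSketch {
  // thread-id -> start time in epoch millis, mirroring handlerExecutionStartTimeMap
  private val startTimes = new ConcurrentHashMap[Long, Long]()

  def main(args: Array[String]): Unit = {
    val timeoutMillis = 2000L // seconds-scale stand-in for the minutes-based timeout

    val reaper = new Thread(() => {
      while (true) {
        startTimes.forEach { (threadId, startTime) =>
          if (System.currentTimeMillis() - startTime > timeoutMillis) {
            // Interrupt the overdue thread; the handler itself is responsible
            // for turning the interruption into an error (see wrapHandler).
            Thread.getAllStackTraces().keySet().asScala
              .find(_.getId == threadId)
              .foreach(_.interrupt())
          }
        }
        Thread.sleep(500) // the patch polls every 60s; shortened here
      }
    })
    reaper.setDaemon(true)
    reaper.start()

    // A "handler" that hangs: register a start time, then block.
    val threadId = Thread.currentThread().getId
    startTimes.put(threadId, System.currentTimeMillis())
    try {
      Thread.sleep(60000)
    } catch {
      case _: InterruptedException =>
        println("reaper interrupted the hanging handler")
    } finally {
      startTimes.remove(threadId)
    }
  }
}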
@@ -410,6 +452,39 @@ private[connect] object MLHandler extends Logging {
     }
   }

+  def wrapHandler(
+      originHandler: () => Any,
+      reqProto: com.google.protobuf.GeneratedMessage): Any = {
+    val threadId = Thread.currentThread().getId
+    val startTime = System.currentTimeMillis()
+    handlerExecutionStartTimeMap.put(threadId, startTime)
+    try {
+      originHandler()
+    } catch {
+      case e: InterruptedException =>
+        val stackTrace = {
+          val sw = new StringWriter()
+          val pw = new PrintWriter(sw)
+          e.printStackTrace(pw)
+          sw.toString
+        }
+        val execTime = (System.currentTimeMillis() - startTime) / (60 * 1000)
+        throw SparkException.internalError(
+          s"The Spark Connect ML handler thread was interrupted after executing for " +
+            s"$execTime minutes.\nThe request proto message is:\n${reqProto.toString}\n" +
+            s"The current stack trace is:\n$stackTrace\n")
+    } finally {
+      handlerExecutionStartTimeMap.remove(threadId)
+    }
+  }
+
+  def handleMlCommand(
+      sessionHolder: SessionHolder,
+      mlCommand: proto.MlCommand): proto.MlCommandResult = {
+    wrapHandler(() => _handleMlCommand(sessionHolder, mlCommand), mlCommand)
+      .asInstanceOf[proto.MlCommandResult]
+  }
+
   private def createModelSummary(
       sessionHolder: SessionHolder,
       createSummaryCmd: proto.MlCommand.CreateSummary): proto.MlCommandResult =
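`wrapHandler` is the other half of the mechanism: it registers the current thread in `handlerExecutionStartTimeMap` before dispatching, always unregisters in `finally`, and converts an `InterruptedException` raised by the reaper into a `SparkException` carrying the request proto and the stack trace. The rename to `_handleMlCommand` keeps the original logic intact while the public `handleMlCommand` now routes through the wrapper. A hypothetical generic variant of the same register/execute/unregister pattern, shown only to illustrate the design (a type parameter would avoid the `asInstanceOf` casts, at the cost of changing the wrapper's signature):

// Sketch with hypothetical names; not part of the patch.
def instrumented[T](
    startTimes: java.util.concurrent.ConcurrentHashMap[Long, Long])(body: => T): T = {
  val threadId = Thread.currentThread().getId
  startTimes.put(threadId, System.currentTimeMillis())
  try {
    body
  } catch {
    case e: InterruptedException =>
      // Stand-in for SparkException.internalError in the patch.
      throw new RuntimeException("handler interrupted by the reaper", e)
  } finally {
    // Always unregister, so the reaper never sees a finished handler.
    startTimes.remove(threadId)
  }
}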
@@ -431,7 +506,9 @@ private[connect] object MLHandler extends Logging {
       .build()
   }

-  def transformMLRelation(relation: proto.MlRelation, sessionHolder: SessionHolder): DataFrame = {
+  def _transformMLRelation(
+      relation: proto.MlRelation,
+      sessionHolder: SessionHolder): DataFrame = {
     relation.getMlTypeCase match {
       // Ml transform
       case proto.MlRelation.MlTypeCase.TRANSFORM =>
@@ -487,4 +564,9 @@ private[connect] object MLHandler extends Logging {
       case other => throw MlUnsupportedException(s"$other not supported")
     }
   }
+
+  def transformMLRelation(relation: proto.MlRelation, sessionHolder: SessionHolder): DataFrame = {
+    wrapHandler(() => _transformMLRelation(relation, sessionHolder), relation)
+      .asInstanceOf[DataFrame]
+  }
 }
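Note that `startHangingHandlerReaper()` runs at `MLHandler` object initialization, so `SPARK_CONNECT_ML_HANDLER_INTERRUPTION_TIMEOUT_MINUTES` must be set in the server's environment before start-up. The feature is off by default: unset or non-numeric values fall back to a timeout of 0, which disables the reaper. A small sketch of that parsing rule, with a hypothetical helper name:

// Hypothetical helper mirroring the env-var handling in startHangingHandlerReaper:
// unset or non-numeric values disable the reaper (0); a positive integer arms it.
def parseTimeoutMinutes(raw: String): Int =
  try {
    if (raw != null) raw.toInt else 0
  } catch {
    case _: Exception => 0
  }

// parseTimeoutMinutes(null) == 0; parseTimeoutMinutes("abc") == 0; parseTimeoutMinutes("30") == 30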