@@ -448,13 +448,6 @@ def magic_matplot(name):
}


# get or create spark session
spark_session = kyuubi_util.get_spark_session(
os.environ.get("KYUUBI_SPARK_SESSION_UUID")
)
global_dict["spark"] = spark_session


def main():
sys_stdin = sys.stdin
sys_stdout = sys.stdout
@@ -487,6 +480,12 @@ def main():
if content["cmd"] == "exit_worker":
break

if content["cmd"] == "set_session_id":
# get or create spark session
global_dict["spark"] = kyuubi_util.get_spark_session(content["session_id"])
continue


result = execute_request(content)

try:
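With this change the Python worker no longer resolves its Spark session from the KYUUBI_SPARK_SESSION_UUID environment variable at startup; it now waits for a `set_session_id` command on stdin and binds the session only then, which is what allows a worker process to be started before any session exists. A minimal driver-side sketch of that command protocol follows; `worker.py` is a hypothetical stand-in for the real script, and the JSON is hand-rolled to keep the sketch dependency-free (the engine itself uses `JsonUtils.toJson`).

```scala
import java.io.PrintWriter

// Sketch only: drives a worker that speaks the same JSON-per-line command
// protocol as the Kyuubi Python worker ("set_session_id", "run_code",
// "exit_worker"). "worker.py" is a hypothetical stand-in for the real script.
object WorkerProtocolSketch {
  // Hand-rolled JSON encoding of a flat string map (the real code uses JsonUtils).
  private def jsonCmd(fields: (String, String)*): String =
    fields.map { case (k, v) => "\"" + k + "\": \"" + v + "\"" }.mkString("{", ", ", "}")

  def main(args: Array[String]): Unit = {
    val process = new ProcessBuilder("python3", "worker.py").start()
    val stdin = new PrintWriter(process.getOutputStream, true)

    // A pooled worker is session-agnostic until it is handed out, so the
    // session id must be pushed to it before any user code runs.
    stdin.println(jsonCmd("cmd" -> "set_session_id", "session_id" -> "<session-uuid>"))
    // Only then is user code submitted.
    stdin.println(jsonCmd("cmd" -> "run_code", "code" -> "print(1)"))
    // And the worker is told to exit when the session closes.
    stdin.println(jsonCmd("cmd" -> "exit_worker"))
    stdin.close()
    process.waitFor()
  }
}
```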
@@ -39,6 +39,7 @@ import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_ENGINE_SUBMIT_TIME_KE
import org.apache.kyuubi.engine.ShareLevel
import org.apache.kyuubi.engine.spark.SparkSQLEngine.{countDownLatch, currentEngine}
import org.apache.kyuubi.engine.spark.events.{EngineEvent, EngineEventsStore, SparkEventHandlerRegister}
import org.apache.kyuubi.engine.spark.operation.PythonWorkerPool
import org.apache.kyuubi.engine.spark.session.{SparkSessionImpl, SparkSQLSessionManager}
import org.apache.kyuubi.events.EventBus
import org.apache.kyuubi.ha.HighAvailabilityConf._
@@ -68,6 +69,7 @@ case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngin
val engineEventListener = new SparkSQLEngineEventListener(kvStore, conf)
spark.sparkContext.addSparkListener(engineEventListener)
super.initialize(conf)
PythonWorkerPool.init(spark, conf)
}

override def start(): Unit = {
@@ -21,7 +21,7 @@ import java.io.{BufferedReader, File, FilenameFilter, FileOutputStream, InputStr
import java.lang.ProcessBuilder.Redirect
import java.net.URI
import java.nio.file.{Files, Path, Paths}
import java.util.concurrent.RejectedExecutionException
import java.util.concurrent.{LinkedBlockingQueue, RejectedExecutionException}
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock

@@ -34,7 +34,8 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.{KyuubiSQLException, Logging, Utils}
import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE, ENGINE_SPARK_PYTHON_MAGIC_ENABLED}
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE, ENGINE_SPARK_PYTHON_MAGIC_ENABLED, PYTHON_WORKER_MAX_POOL_SIZE}
import org.apache.kyuubi.config.KyuubiConf.EngineSparkOutputMode.{AUTO, EngineSparkOutputMode, NOTEBOOK}
import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_USER_KEY, KYUUBI_STATEMENT_ID_KEY}
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
@@ -185,14 +186,17 @@ class ExecutePython(
}

case class SessionPythonWorker(
errorReader: Thread,
pythonWorkerMonitor: Thread,
errorReaderFunc: String => Thread,
pythonWorkerMonitorFunc: String => Thread,
workerProcess: Process) {
private val stdin: PrintWriter = new PrintWriter(workerProcess.getOutputStream)
private val stdout: BufferedReader =
new BufferedReader(new InputStreamReader(workerProcess.getInputStream), 1)
private val lock = new ReentrantLock()

private var errorReader: Option[Thread] = None
private var pythonWorkerMonitor: Option[Thread] = None

private def withLockRequired[T](block: => T): T = Utils.withLockRequired(lock)(block)

/**
@@ -207,10 +211,7 @@ case class SessionPythonWorker(
* @return the python response
*/
def runCode(code: String, internal: Boolean = false): Option[PythonResponse] = withLockRequired {
if (!workerProcess.isAlive) {
throw KyuubiSQLException("Python worker process has been exited, please check the error log" +
" and re-create the session to run python code.")
}
checkPythonWorkerStatus()
val input = JsonUtils.toJson(Map("code" -> code, "cmd" -> "run_code"))
// scalastyle:off println
stdin.println(input)
@@ -224,6 +225,24 @@ case class SessionPythonWorker(
pythonResponse
}

def checkPythonWorkerStatus(): Unit = {
if (!workerProcess.isAlive) {
throw KyuubiSQLException("Python worker process has exited, please check the error log" +
" and re-create the session to run python code.")
}
}

def setSessionId(sessionId: String): Unit = {
checkPythonWorkerStatus()
errorReader = Some(errorReaderFunc.apply(sessionId))
pythonWorkerMonitor = Some(pythonWorkerMonitorFunc.apply(sessionId))
val input = JsonUtils.toJson(Map("session_id" -> sessionId, "cmd" -> "set_session_id"))
// scalastyle:off println
stdin.println(input)
// scalastyle:on
stdin.flush()
}

def close(): Unit = {
val exitCmd = JsonUtils.toJson(Map("cmd" -> "exit_worker"))
// scalastyle:off println
@@ -232,8 +251,8 @@ case class SessionPythonWorker(
stdin.flush()
stdin.close()
stdout.close()
errorReader.interrupt()
pythonWorkerMonitor.interrupt()
errorReader.foreach(_.interrupt())
pythonWorkerMonitor.foreach(_.interrupt())
workerProcess.destroy()
}

@@ -253,6 +272,51 @@ case class SessionPythonWorker(
}
}

object PythonWorkerPool extends Logging {

private var pythonWorkerPool: PythonWorkerPool = _

def init(spark: SparkSession, kyuubiConf: KyuubiConf): Unit = {
ExecutePython.init()
val poolSize = kyuubiConf.get(PYTHON_WORKER_MAX_POOL_SIZE)
pythonWorkerPool = new PythonWorkerPool(poolSize, spark)
}

def getPythonWorker(session: Session): SessionPythonWorker = {
pythonWorkerPool.getPythonWorker(session)
}
}

class PythonWorkerPool(poolSize: Int, spark: SparkSession) extends Logging {
@volatile private var stopped = false
private val pythonWorkerPool = new LinkedBlockingQueue[SessionPythonWorker](poolSize)
private val pythonWorkerProducer = new Thread("python worker producer") {
override def run(): Unit = {
while (!stopped) {
logger.info(s"start create python worker")
pythonWorkerPool.put(ExecutePython.createSessionPythonWorker(spark))
logger.info(s"end create python worker")
}
}
}
pythonWorkerProducer.setDaemon(true)
pythonWorkerProducer.start()

def close(): Unit = {
stopped = true
pythonWorkerProducer.interrupt()
}

def getPythonWorker(session: Session): SessionPythonWorker = {
logger.info(s"start get python worker")
val pythonWorker = pythonWorkerPool.take()
logger.info(s"end get python worker")
val sessionId = session.handle.identifier.toString
pythonWorker.setSessionId(sessionId)
pythonWorker
}
}

object ExecutePython extends Logging {
final val DEFAULT_SPARK_PYTHON_HOME_ARCHIVE_FRAGMENT = "__kyuubi_spark_python_home__"
final val DEFAULT_SPARK_PYTHON_ENV_ARCHIVE_FRAGMENT = "__kyuubi_spark_python_env__"
@@ -276,14 +340,13 @@ object ExecutePython extends Logging {
}
}

def createSessionPythonWorker(spark: SparkSession, session: Session): SessionPythonWorker = {
val sessionId = session.handle.identifier.toString
def createSessionPythonWorker(spark: SparkSession): SessionPythonWorker = {
val pythonExec = StringUtils.firstNonBlank(
spark.conf.getOption("spark.pyspark.driver.python").orNull,
spark.conf.getOption("spark.pyspark.python").orNull,
System.getenv("PYSPARK_DRIVER_PYTHON"),
System.getenv("PYSPARK_PYTHON"),
getSparkPythonExecFromArchive(spark, session).getOrElse("python3"))
getSparkPythonExecFromArchive(spark).getOrElse("python3"))

val builder = new ProcessBuilder(Seq(
pythonExec,
@@ -303,9 +366,8 @@
"SPARK_HOME",
sys.env.getOrElse(
"SPARK_HOME",
getSparkPythonHomeFromArchive(spark, session).getOrElse(defaultSparkHome)))
getSparkPythonHomeFromArchive(spark).getOrElse(defaultSparkHome)))
}
env.put("KYUUBI_SPARK_SESSION_UUID", sessionId)
env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH)
env.put(MAGIC_ENABLED, getSessionConf(ENGINE_SPARK_PYTHON_MAGIC_ENABLED, spark).toString)
logger.info(
@@ -317,12 +379,12 @@
builder.redirectError(Redirect.PIPE)
val process = builder.start()
SessionPythonWorker(
startStderrSteamReader(process, sessionId),
startWatcher(process, sessionId),
sessionId => startStderrSteamReader(process, sessionId),
sessionId => startWatcher(process, sessionId),
process)
}

def getSparkPythonExecFromArchive(spark: SparkSession, session: Session): Option[String] = {
def getSparkPythonExecFromArchive(spark: SparkSession): Option[String] = {
val pythonEnvArchive = getSessionConf(ENGINE_SPARK_PYTHON_ENV_ARCHIVE, spark)
val pythonEnvExecPath = getSessionConf(ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, spark)
pythonEnvArchive.map {
@@ -336,7 +398,7 @@
}.find(Files.exists(_)).map(_.toAbsolutePath.toFile.getCanonicalPath)
}

def getSparkPythonHomeFromArchive(spark: SparkSession, session: Session): Option[String] = {
def getSparkPythonHomeFromArchive(spark: SparkSession): Option[String] = {
val pythonHomeArchive = getSessionConf(ENGINE_SPARK_PYTHON_HOME_ARCHIVE, spark)
pythonHomeArchive.map {
archive =>
@@ -365,7 +427,7 @@ object ExecutePython extends Logging {
}
}

private def startStderrSteamReader(process: Process, sessionId: String): Thread = {
private def startStderrSteamReader(process: Process, sessionId: String = ""): Thread = {
val stderrThread = new Thread(s"session[$sessionId] process stderr thread") {
override def run(): Unit = {
val lines = scala.io.Source.fromInputStream(process.getErrorStream).getLines()
Expand All @@ -377,7 +439,7 @@ object ExecutePython extends Logging {
stderrThread
}

def startWatcher(process: Process, sessionId: String): Thread = {
def startWatcher(process: Process, sessionId: String = ""): Thread = {
val processWatcherThread = new Thread(s"session[$sessionId] process watcher thread") {
override def run(): Unit = {
try {
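The heart of the change is the pre-warmed pool: a daemon producer thread keeps a bounded LinkedBlockingQueue topped up with already-started, session-agnostic workers, and `getPythonWorker` simply takes one and binds it to the session via `setSessionId`. The stderr reader and process watcher threads are now built from `String => Thread` factories inside `setSessionId`, so their names can carry the real session id once a worker is handed to a session. A stripped-down sketch of the same pattern, with a type parameter standing in for `SessionPythonWorker`:

```scala
import java.util.concurrent.LinkedBlockingQueue

// Sketch only: the pre-warmed pool pattern, generic over the worker type.
// The bounded queue means the producer blocks on put() once `poolSize`
// workers are waiting, and only builds a replacement after one is taken.
class PrewarmedPool[T](poolSize: Int, create: () => T) {
  @volatile private var stopped = false
  private val queue = new LinkedBlockingQueue[T](poolSize)

  private val producer = new Thread("pool producer") {
    override def run(): Unit =
      try {
        while (!stopped) queue.put(create())
      } catch {
        case _: InterruptedException => // close() interrupts a blocked put()
      }
  }
  producer.setDaemon(true)
  producer.start()

  /** Blocks until a pre-created worker is available. */
  def borrow(): T = queue.take()

  def close(): Unit = {
    stopped = true
    producer.interrupt()
  }
}

object PrewarmedPoolDemo extends App {
  // Strings stand in for SessionPythonWorker; creation is instant here,
  // whereas spawning a real Python process is what the pool amortises.
  val pool = new PrewarmedPool[String](2, () => s"worker-${System.nanoTime()}")
  println(pool.borrow())
  println(pool.borrow())
  pool.close()
}
```

Because the queue capacity equals the configured pool size, `put()` blocks once the pool is full and a replacement worker is only spawned after one is taken, so bursts of new sessions beyond the pool size still pay the cost of starting a fresh worker.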
@@ -116,10 +116,9 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
new ExecuteScala(session, repl, statement, runAsync, queryTimeout, opHandle)
case OperationLanguages.PYTHON =>
try {
ExecutePython.init()
val worker = sessionToPythonProcess.getOrElseUpdate(
session.handle,
ExecutePython.createSessionPythonWorker(spark, session))
PythonWorkerPool.getPythonWorker(session))
new ExecutePython(session, statement, runAsync, queryTimeout, worker, opHandle)
} catch {
case e: Throwable =>
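The operation manager keeps its session-to-worker map, so each session still owns exactly one worker for its lifetime; only the creation path changes, from spawning a worker on a session's first PYTHON statement to taking a pre-warmed one from the pool (the one-time `ExecutePython.init()` likewise moves into `PythonWorkerPool.init`, called at engine start). A minimal sketch of the lookup-or-take binding, with strings standing in for `SessionHandle` and `SessionPythonWorker`:

```scala
import scala.collection.concurrent.TrieMap

// Sketch only: bind each session to one worker, consulting the pool just once.
object SessionWorkerBindingSketch extends App {
  private val sessionToWorker = TrieMap.empty[String, String]

  // Stands in for PythonWorkerPool.getPythonWorker(session).
  private def takeFromPool(): String = s"worker-${System.nanoTime()}"

  def workerFor(sessionId: String): String =
    sessionToWorker.getOrElseUpdate(sessionId, takeFromPool())

  val first = workerFor("session-1")
  assert(workerFor("session-1") eq first) // later statements reuse the same worker
}
```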
@@ -3728,4 +3728,11 @@ object KyuubiConf {
.version("1.9.1")
.serverOnly
.fallbackConf(HIVE_SERVER2_THRIFT_RESULTSET_DEFAULT_FETCH_SIZE)

val PYTHON_WORKER_MAX_POOL_SIZE: ConfigEntry[Int] =
buildConf("python.worker.max.pool.size")
.doc("Set maximum quantity of python worker pool")
.version("1.10.0")
.intConf
.createWithDefault(1)
}
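The new entry defaults to a single pre-warmed worker, which keeps today's one-worker-per-session behaviour while hiding the startup latency of the first worker. A hedged example of raising it; note that the key is used exactly as spelled in the entry above, so if it is later renamed (for example, given a `kyuubi.` prefix) the string must change with it:

```scala
import org.apache.kyuubi.config.KyuubiConf

// Sketch only: pre-warm four Python workers instead of the default single one.
object PoolSizeConfigSketch extends App {
  val conf = KyuubiConf()
  conf.set("python.worker.max.pool.size", "4")
  assert(conf.get(KyuubiConf.PYTHON_WORKER_MAX_POOL_SIZE) == 4)
}
```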
@@ -242,6 +242,14 @@ class PySparkTests extends WithKyuubiServer with HiveJDBCTestHelper {
})
}

test("python worker process pooling") {
checkPythonRuntimeAndVersion()
val code = "print(1)"
val output = "1"
runPySparkTest(code, output)
runPySparkTest(code, output)
}

private def runPySparkTest(
pyCode: String,
output: String): Unit = {