
Commit 5b3526f

Clean shuffle data (#312) (#322)

* Cleanup Spark shuffle data after data is consumed
* update comments

(cherry picked from commit 222e473)
Signed-off-by: Peng Huo <penghuo@gmail.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

1 parent 6001731 · commit 5b3526f

File tree: 7 files changed, +112 −36 lines

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala

Lines changed: 1 addition & 10 deletions
```diff
@@ -6,24 +6,15 @@
 // defined in spark package so that I can use ThreadUtils
 package org.apache.spark.sql
 
-import java.util.Locale
 import java.util.concurrent.atomic.AtomicInteger
 
-import org.opensearch.client.{RequestOptions, RestHighLevelClient}
-import org.opensearch.cluster.metadata.MappingMetadata
-import org.opensearch.common.settings.Settings
-import org.opensearch.common.xcontent.XContentType
-import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
-import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
 import play.api.libs.json._
 
-import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.types.{StructField, _}
+import org.apache.spark.sql.types._
 
 /**
  * Spark SQL Application entrypoint
```

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala

Lines changed: 15 additions & 15 deletions
```diff
@@ -7,26 +7,18 @@ package org.apache.spark.sql
 
 import java.util.Locale
 
-import scala.concurrent.{ExecutionContext, Future, TimeoutException}
-import scala.concurrent.duration.{Duration, MINUTES}
-
 import com.amazonaws.services.s3.model.AmazonS3Exception
 import org.apache.commons.text.StringEscapeUtils.unescapeJava
-import org.opensearch.flint.core.{FlintClient, IRestHighLevelClient}
-import org.opensearch.flint.core.metadata.FlintMetadata
+import org.opensearch.flint.core.IRestHighLevelClient
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
-import play.api.libs.json.{JsArray, JsBoolean, JsObject, Json, JsString, JsValue}
+import play.api.libs.json._
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintREPL.envinromentProvider
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}
-import org.apache.spark.sql.util.{DefaultThreadPoolFactory, EnvironmentProvider, RealEnvironment, RealTimeProvider, ThreadPoolFactory, TimeProvider}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util._
 
 trait FlintJobExecutor {
   this: Logging =>
@@ -157,7 +149,8 @@ trait FlintJobExecutor {
       query: String,
       sessionId: String,
       startTime: Long,
-      timeProvider: TimeProvider): DataFrame = {
+      timeProvider: TimeProvider,
+      cleaner: Cleaner): DataFrame = {
     // Create the schema dataframe
     val schemaRows = result.schema.fields.map { field =>
       Row(field.name, field.dataType.typeName)
@@ -192,6 +185,11 @@ trait FlintJobExecutor {
     val resultSchemaToSave = resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'"))
     val endTime = timeProvider.currentEpochMillis()
 
+    // https://github.com/opensearch-project/opensearch-spark/issues/302. Clean shuffle data
+    // after consumed the query result. Streaming query shuffle data is cleaned after each
+    // microBatch execution.
+    cleaner.cleanUp(spark)
+
     // Create the data rows
     val rows = Seq(
       (
@@ -375,7 +373,8 @@ trait FlintJobExecutor {
      query: String,
      dataSource: String,
      queryId: String,
-     sessionId: String): DataFrame = {
+     sessionId: String,
+     streaming: Boolean): DataFrame = {
    // Execute SQL query
    val startTime = System.currentTimeMillis()
    // we have to set job group in the same thread that started the query according to spark doc
@@ -390,7 +389,8 @@ trait FlintJobExecutor {
      query,
      sessionId,
      startTime,
-     currentTimeProvider)
+     currentTimeProvider,
+     CleanerFactory.cleaner(streaming))
  }
 
  private def handleQueryException(
```
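The ordering in the `getFormattedData` hunk above matters: `cleaner.cleanUp(spark)` runs only after `toJSON.collect` has materialized the result, so the shuffle files backing the finished stages are no longer needed. A minimal sketch of that pattern, assuming a plain `SparkSession`; the helper name `collectThenClean` is hypothetical, while `Cleaner` and `CleanerFactory` come from this commit:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.util.CleanerFactory

// Hypothetical helper mirroring the collect-then-clean ordering above.
// Cleaning before collect() could discard shuffle data the running job still
// needs; cleaning after is safe because the rows are already on the driver.
def collectThenClean(spark: SparkSession, df: DataFrame, streaming: Boolean): List[String] = {
  val rows = df.toJSON.collect().toList            // consume the query result
  CleanerFactory.cleaner(streaming).cleanUp(spark) // ShuffleCleaner for batch, NoOpCleaner for streaming
  rows
}
```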

spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala

Lines changed: 8 additions & 4 deletions
```diff
@@ -22,15 +22,13 @@ import org.opensearch.flint.app.{FlintCommand, FlintInstance}
 import org.opensearch.flint.app.FlintInstance.formats
 import org.opensearch.flint.core.FlintOptions
 import org.opensearch.flint.core.metrics.MetricConstants
-import org.opensearch.flint.core.metrics.MetricsUtil.{decrementCounter, getTimerContext, incrementCounter, registerGauge, stopTimer}
+import org.opensearch.flint.core.metrics.MetricsUtil.{getTimerContext, incrementCounter, registerGauge, stopTimer}
 import org.opensearch.flint.core.storage.{FlintReader, OpenSearchUpdater}
 import org.opensearch.search.sort.SortOrder
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintJob.createSparkSession
 import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.flint.config.FlintSparkConf.REPL_INACTIVITY_TIMEOUT_MILLIS
 import org.apache.spark.sql.util.{DefaultShutdownHookManager, ShutdownHookManagerTrait}
 import org.apache.spark.util.ThreadUtils
 
@@ -829,7 +827,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
             startTime)
         } else {
           val futureQueryExecution = Future {
-            executeQuery(spark, flintCommand.query, dataSource, flintCommand.queryId, sessionId)
+            executeQuery(
+              spark,
+              flintCommand.query,
+              dataSource,
+              flintCommand.queryId,
+              sessionId,
+              false)
           }(executionContext)
 
           // time out after 10 minutes
```

spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala

Lines changed: 5 additions & 5 deletions
```diff
@@ -14,13 +14,10 @@ import scala.util.{Failure, Success, Try}
 
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter
-import org.opensearch.flint.core.storage.OpenSearchUpdater
 
-import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.FlintJob.createSparkSession
-import org.apache.spark.sql.FlintREPL.{executeQuery, logInfo, threadPoolFactory, updateFlintInstanceBeforeShutdown}
 import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.util.ShuffleCleaner
 import org.apache.spark.util.ThreadUtils
 
 case class JobOperator(
@@ -53,7 +50,7 @@ case class JobOperator(
     val futureMappingCheck = Future {
       checkAndCreateIndex(osClient, resultIndex)
     }
-    val data = executeQuery(spark, query, dataSource, "", "")
+    val data = executeQuery(spark, query, dataSource, "", "", streaming)
 
     val mappingCheckResult = ThreadUtils.awaitResult(futureMappingCheck, Duration(1, MINUTES))
     dataToWrite = Some(mappingCheckResult match {
@@ -92,6 +89,8 @@ case class JobOperator(
     try {
       // Wait for streaming job complete if no error and there is streaming job running
       if (!exceptionThrown && streaming && spark.streams.active.nonEmpty) {
+        // Clean Spark shuffle data after each microBatch.
+        spark.streams.addListener(new ShuffleCleaner(spark))
         // wait if any child thread to finish before the main thread terminates
         spark.streams.awaitAnyTermination()
       }
@@ -149,4 +148,5 @@ case class JobOperator(
       case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC)
     }
   }
+
 }
```
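In the streaming branch above, `ShuffleCleaner` is attached as a `StreamingQueryListener`, so its `onQueryProgress` callback runs once per completed microBatch. A self-contained sketch of that hook on an assumed local Spark 3.x session, with a logging stub standing in for the real cleanup:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

object ListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("listener-sketch").getOrCreate()

    spark.streams.addListener(new StreamingQueryListener {
      override def onQueryStarted(e: QueryStartedEvent): Unit = {}
      override def onQueryProgress(e: QueryProgressEvent): Unit = {
        // In the real listener this is where ShuffleCleaner.cleanUp(spark) runs.
        println(s"microBatch ${e.progress.batchId} finished")
      }
      override def onQueryTerminated(e: QueryTerminatedEvent): Unit = {}
    })

    val query = spark.readStream.format("rate").load()
      .writeStream.format("noop").start()
    query.awaitTermination(10000) // let a few batches complete, then exit
    spark.stop()
  }
}
```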
New file (package org.apache.spark.sql.util)

Lines changed: 57 additions & 0 deletions

```scala
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql.util

import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener

/**
 * Clean Spark shuffle data after each microBatch.
 * https://github.com/opensearch-project/opensearch-spark/issues/302
 */
class ShuffleCleaner(spark: SparkSession) extends StreamingQueryListener with Logging {

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    ShuffleCleaner.cleanUp(spark)
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}
}

trait Cleaner {
  def cleanUp(spark: SparkSession)
}

object CleanerFactory {
  def cleaner(streaming: Boolean): Cleaner = {
    if (streaming) NoOpCleaner else ShuffleCleaner
  }
}

/**
 * No operation cleaner.
 */
object NoOpCleaner extends Cleaner {
  override def cleanUp(spark: SparkSession): Unit = {}
}

/**
 * Spark shuffle data cleaner.
 */
object ShuffleCleaner extends Cleaner with Logging {
  def cleanUp(spark: SparkSession): Unit = {
    logInfo("Before cleanUp Shuffle")
    val cleaner = spark.sparkContext.cleaner
    val masterTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
    val shuffleIds = masterTracker.shuffleStatuses.keys.toSet
    shuffleIds.foreach(shuffleId => cleaner.foreach(c => c.doCleanupShuffle(shuffleId, true)))
    logInfo("After cleanUp Shuffle")
  }
}
```
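A note on the design: `cleaner(streaming = true)` returning `NoOpCleaner` is deliberate, since streaming queries are cleaned per microBatch by the listener rather than once after result collection. The cleanup itself goes through `ContextCleaner.doCleanupShuffle` with shuffle ids read from `MapOutputTrackerMaster`, both Spark-internal (`private[spark]`) APIs, which is presumably why this file sits under an `org.apache.spark` subpackage, echoing the comment in FlintJob.scala about its own package choice. A hedged wiring sketch of the two call sites, using only names from this commit; the helper `wireCleanup` is hypothetical:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.util.{CleanerFactory, ShuffleCleaner}

// Hypothetical helper showing the two cleanup paths introduced by this commit.
def wireCleanup(spark: SparkSession, streaming: Boolean): Unit = {
  if (streaming) {
    // Streaming: clean after every microBatch via the listener; the factory's
    // NoOpCleaner makes the post-collect call in getFormattedData a no-op.
    spark.streams.addListener(new ShuffleCleaner(spark))
  } else {
    // Batch: clean once, after the query result has been consumed.
    CleanerFactory.cleaner(streaming = false).cleanUp(spark)
  }
}
```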

spark-sql-application/src/test/scala/org/apache/spark/sql/FlintJobTest.scala

Lines changed: 3 additions & 2 deletions
```diff
@@ -7,7 +7,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.MockTimeProvider
+import org.apache.spark.sql.util.{CleanerFactory, MockTimeProvider}
 
 class FlintJobTest extends SparkFunSuite with JobMatchers {
 
@@ -76,7 +76,8 @@ class FlintJobTest extends SparkFunSuite with JobMatchers {
       "select 1",
       "20",
       currentTime - queryRunTime,
-      new MockTimeProvider(currentTime))
+      new MockTimeProvider(currentTime),
+      CleanerFactory.cleaner(false))
     assertEqualDataframe(expected, result)
   }
 
```

New file (package org.apache.spark.sql.util)

Lines changed: 23 additions & 0 deletions

```scala
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.apache.spark.sql.util

import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite

class CleanerFactoryTest extends SparkFunSuite with Matchers {

  test("CleanerFactory should return NoOpCleaner when streaming is true") {
    val cleaner = CleanerFactory.cleaner(streaming = true)
    cleaner shouldBe NoOpCleaner
  }

  test("CleanerFactory should return ShuffleCleaner when streaming is false") {
    val cleaner = CleanerFactory.cleaner(streaming = false)
    cleaner shouldBe ShuffleCleaner
  }
}
```
