[SPARK-53785][SS] Memory Source for RTM #52502
@@ -0,0 +1,39 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.util.{Clock, SystemClock}

/* The singleton object to control the time in testing */
object LowLatencyClock {
  private var clock: Clock = new SystemClock

  def getClock: Clock = clock

  def getTimeMillis(): Long = {
    clock.getTimeMillis()
  }

  def waitTillTime(targetTime: Long): Unit = {
    clock.waitTillTime(targetTime)
  }

  def setClock(inputClock: Clock): Unit = {
    clock = inputClock
  }
}
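
Because LowLatencyClock is a process-wide singleton, a test can swap in a manual clock and drive time deterministically. Below is a minimal sketch of that pattern, assuming Spark's existing org.apache.spark.util.ManualClock is accessible from the test; the snippet is illustrative and not part of this diff.

```scala
// Illustrative sketch only (not part of this PR): replace the shared clock with
// a ManualClock so that time advances only when the test says so.
import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
import org.apache.spark.util.{ManualClock, SystemClock}

val manualClock = new ManualClock()            // ManualClock starts at time 0
LowLatencyClock.setClock(manualClock)          // readers of LowLatencyClock now see manual time
assert(LowLatencyClock.getTimeMillis() == 0L)

manualClock.advance(5000L)                     // move the clock forward by 5 seconds
assert(LowLatencyClock.getTimeMillis() == 5000L)

// LowLatencyClock.waitTillTime(t) now blocks until another thread advances the
// manual clock past t, which is how a test can gate the source's progress.

LowLatencyClock.setClock(new SystemClock)      // restore real time after the test
```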
@@ -0,0 +1,300 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ListBuffer

import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}
import org.apache.spark.sql.{Encoder, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.connector.read.streaming.{
  Offset => OffsetV2,
  PartitionOffset,
  ReadLimit,
  SupportsRealTimeMode,
  SupportsRealTimeRead
}
import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus
import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
import org.apache.spark.sql.execution.streaming.runtime._
import org.apache.spark.util.{Clock, RpcUtils}

/**
 * A low-latency in-memory source, intended only for unit testing.
 * This class is very similar to ContinuousMemoryStream, except that it implements the
 * SupportsRealTimeMode interface rather than ContinuousStream.
 * The overall strategy here is:
 *  * LowLatencyMemoryStream maintains a list of records for each partition. addData()
 *    distributes records roughly evenly across partitions.
 *  * RecordEndpoint is set up as an endpoint for executor-side
 *    LowLatencyMemoryStreamInputPartitionReader instances to poll. It returns the record at
 *    the specified offset within the list, or null if that offset doesn't yet have a record.
 *    This differs from the existing memory source implementation, where data is sent to
 *    tasks once, as part of the Partition/Split metadata at the beginning of a batch.
 */
class LowLatencyMemoryStream[A: Encoder](
    id: Int,
    sqlContext: SQLContext,
    numPartitions: Int = 2,
    clock: Clock = LowLatencyClock.getClock)
  extends MemoryStreamBaseClass[A](0, sqlContext)
  with SupportsRealTimeMode {
  private implicit val formats: Formats = Serialization.formats(NoTypeHints)

  @GuardedBy("this")
  private val records = Seq.fill(numPartitions)(new ListBuffer[UnsafeRow])

  private val recordEndpoint = new ContinuousRecordEndpoint(records, this)
  @volatile private var endpointRef: RpcEndpointRef = _

  override def addData(data: IterableOnce[A]): Offset = synchronized {
    // Distribute data evenly among partition lists.
    data.iterator.to(Seq).zipWithIndex.map {
      case (item, index) =>
        records(index % numPartitions) += toRow(item).copy().asInstanceOf[UnsafeRow]
    }

    // The new target offset is the offset where all records in all partitions have been processed.
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
  }

  def addData(partitionId: Int, data: IterableOnce[A]): Offset = synchronized {
    require(
      partitionId >= 0 && partitionId < numPartitions,
      s"Partition ID $partitionId is out of bounds for $numPartitions partitions."
    )

    // Add data to the specified partition.
    records(partitionId) ++= data.iterator.map(item => toRow(item).copy().asInstanceOf[UnsafeRow])

    // The new target offset is the offset where all records in all partitions have been processed.
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
  }

  override def initialOffset(): OffsetV2 = {
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap)
  }

  override def latestOffset(startOffset: OffsetV2, limit: ReadLimit): OffsetV2 = synchronized {
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
  }

Reviewer comment on latestOffset: Do we also need to add `synchronized` here?

Author reply: This is an interesting point. For RTM, the offset returned from latestOffset is actually not used: for non-RTM streaming queries the offset returned from latestOffset defines the end offset of a batch, while in RTM the end offset of a batch is calculated when the batch finishes. However, this source also supports non-RTM queries. In streaming we typically use the StreamTest framework, which executes test actions and batches in synchronized steps, so a race should not happen; still, as a best practice I will add synchronized to the method.

  override def deserializeOffset(json: String): LowLatencyMemoryStreamOffset = {
    LowLatencyMemoryStreamOffset(Serialization.read[Map[Int, Int]](json))
  }

  override def mergeOffsets(offsets: Array[PartitionOffset]): LowLatencyMemoryStreamOffset = {
    LowLatencyMemoryStreamOffset(
      offsets.map {
        case ContinuousRecordPartitionOffset(part, num) => (part, num)
      }.toMap
    )
  }

  override def planInputPartitions(start: OffsetV2): Array[InputPartition] = {
    val startOffset = start.asInstanceOf[LowLatencyMemoryStreamOffset]
    synchronized {
      val endpointName = s"LowLatencyRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
      endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)

      startOffset.partitionNums.map {
        case (part, index) =>
          LowLatencyMemoryStreamInputPartition(
            endpointName,
            endpointRef.address,
            part,
            index,
            Int.MaxValue
          )
      }.toArray
    }
  }

  override def planInputPartitions(start: OffsetV2, end: OffsetV2): Array[InputPartition] = {
    val startOffset = start.asInstanceOf[LowLatencyMemoryStreamOffset]
    val endOffset = end.asInstanceOf[LowLatencyMemoryStreamOffset]
    synchronized {
      val endpointName = s"LowLatencyRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
      endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)

      startOffset.partitionNums.map {
        case (part, index) =>
          LowLatencyMemoryStreamInputPartition(
            endpointName,
            endpointRef.address,
            part,
            index,
            endOffset.partitionNums(part)
          )
      }.toArray
    }
  }

  override def createReaderFactory(): PartitionReaderFactory = {
    new LowLatencyMemoryStreamReaderFactory(clock)
  }

  override def stop(): Unit = {
    if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
  }

  override def commit(end: OffsetV2): Unit = {}

  override def reset(): Unit = synchronized {
    super.reset()
    records.foreach(_.clear())
  }
}

Reviewer comment on reset(): Need …

Author reply: will add

object LowLatencyMemoryStream {
  protected val memoryStreamId = new AtomicInteger(0)

  def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)

  def apply[A: Encoder](numPartitions: Int)(
      implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
    new LowLatencyMemoryStream[A](
      memoryStreamId.getAndIncrement(),
      sqlContext,
      numPartitions = numPartitions
    )

  def singlePartition[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
}

/**
 * An input partition for LowLatency memory stream.
 */
case class LowLatencyMemoryStreamInputPartition(
    driverEndpointName: String,
    driverEndpointAddress: RpcAddress,
    partition: Int,
    startOffset: Int,
    endOffset: Int)
  extends InputPartition

class LowLatencyMemoryStreamReaderFactory(clock: Clock) extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    val p = partition.asInstanceOf[LowLatencyMemoryStreamInputPartition]
    new LowLatencyMemoryStreamPartitionReader(
      p.driverEndpointName,
      p.driverEndpointAddress,
      p.partition,
      p.startOffset,
      p.endOffset,
      clock
    )
  }
}

/**
 * An input partition reader for the LowLatency memory stream.
 *
 * Polls the driver endpoint for new records.
 */
class LowLatencyMemoryStreamPartitionReader(
    driverEndpointName: String,
    driverEndpointAddress: RpcAddress,
    partition: Int,
    startOffset: Int,
    endOffset: Int,
    clock: Clock)
  extends SupportsRealTimeRead[InternalRow] {
  // Avoid tracking the ref, given that we create a new one for each partition reader
  // (a new driver endpoint is created for each LowLatencyMemoryStream). If we tracked the ref,
  // we could end up with a lot of refs (thousands) when a test suite has many test cases, which
  // can lead to issues with the tracking array and make the test suite flaky.
  private val endpoint = RpcUtils.makeDriverRef(
    driverEndpointName,
    driverEndpointAddress.host,
    driverEndpointAddress.port,
    SparkEnv.get.rpcEnv
  )

  private var currentOffset = startOffset
  private var current: Option[InternalRow] = None

  // Defense-in-depth against failing to propagate the task context. Since it's not inheritable,
  // we have to do a bit of error-prone work to get it into every thread used by LowLatency
  // processing. We hope that some unit test will end up instantiating a LowLatency memory stream
  // in such cases.
  if (TaskContext.get() == null) {
    throw new IllegalStateException("Task context was not set!")
  }

  override def nextWithTimeout(timeout: java.lang.Long): RecordStatus = {
    val startReadTime = clock.nanoTime()
    var elapsedTimeMs = 0L
    current = getRecord
    while (current.isEmpty) {
      val POLL_TIME = 10L
      if (elapsedTimeMs >= timeout) {
        return RecordStatus.newStatusWithoutArrivalTime(false)
      }
      Thread.sleep(POLL_TIME)
      current = getRecord
      elapsedTimeMs = (clock.nanoTime() - startReadTime) / 1000 / 1000
    }
    currentOffset += 1
    RecordStatus.newStatusWithoutArrivalTime(true)
  }

  override def next(): Boolean = {
    current = getRecord
    if (current.isDefined) {
      currentOffset += 1
      true
    } else {
      false
    }
  }

  override def get(): InternalRow = current.get

  override def close(): Unit = {}

  override def getOffset: ContinuousRecordPartitionOffset =
    ContinuousRecordPartitionOffset(partition, currentOffset)

  private def getRecord: Option[InternalRow] = {
    if (currentOffset >= endOffset) {
      return None
    }
    endpoint.askSync[Option[InternalRow]](
      GetRecord(ContinuousRecordPartitionOffset(partition, currentOffset))
    )
  }
}

case class LowLatencyMemoryStreamOffset(partitionNums: Map[Int, Int]) extends Offset {
  private implicit val formats: Formats = Serialization.formats(NoTypeHints)
  override def json(): String = Serialization.write(partitionNums)
}
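
To make the strategy described in the class comment concrete, here is a rough usage sketch, not part of the diff, showing only the driver-side behaviour of addData() and the offsets it reports; the streaming-query wiring is elided, and the toDF() call mentioned at the end is an assumption about what the MemoryStreamBaseClass parent exposes (the existing memory streams do).

```scala
// Illustrative sketch only (not part of this PR): drive LowLatencyMemoryStream
// from a test and inspect the offsets it reports.
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.{LowLatencyMemoryStream, LowLatencyMemoryStreamOffset}

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext

// Two partitions; addData() distributes records round-robin across them.
val stream = LowLatencyMemoryStream[Int](numPartitions = 2)

val offset = stream.addData(Seq(1, 2, 3))
// Rows 1 and 3 land in partition 0 and row 2 in partition 1, so the offset maps
// each partition to the number of rows it now holds.
assert(stream.deserializeOffset(offset.json()) ==
  LowLatencyMemoryStreamOffset(Map(0 -> 2, 1 -> 1)))

// Data can also be appended to one specific partition.
stream.addData(1, Seq(4, 5))

// In a real test the stream would back a streaming query, e.g. (assuming the
// base class exposes toDF() like the existing memory streams):
//   val query = stream.toDF().writeStream.format("noop").start()
```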
Author comment: Planning to add additional code to this file, i.e. the actual implementation of RealTimeStreamScanExec. Adding LowLatencyClock here since it is used by LowLatencyMemorySource. Trying to keep the PRs small :)

Reviewer reply: Do you mean you will add the code of RealTimeStreamScanExec in another PR?

Author reply: This is the follow-up PR, to provide more context:
#52620