Spark Streaming Backpressure (Rate Limiting): A Source Code Walkthrough

1. RateController

RateController extends StreamingListener:

// RateController is registered on the StreamingListenerBus.
// StreamingListenerBus extends SparkListener with ListenerBus[StreamingListener, StreamingListenerEvent]
protected override def doPostEvent(
    listener: StreamingListener,
    event: StreamingListenerEvent): Unit = {
  event match {
    case receiverStarted: StreamingListenerReceiverStarted =>
      listener.onReceiverStarted(receiverStarted)
    case receiverError: StreamingListenerReceiverError =>
      listener.onReceiverError(receiverError)
    case receiverStopped: StreamingListenerReceiverStopped =>
      listener.onReceiverStopped(receiverStopped)
    case batchSubmitted: StreamingListenerBatchSubmitted =>
      listener.onBatchSubmitted(batchSubmitted)
    case batchStarted: StreamingListenerBatchStarted =>
      listener.onBatchStarted(batchStarted)
    // the event the RateController listener reacts to
    case batchCompleted: StreamingListenerBatchCompleted =>
      listener.onBatchCompleted(batchCompleted)
    case outputOperationStarted: StreamingListenerOutputOperationStarted =>
      listener.onOutputOperationStarted(outputOperationStarted)
    case outputOperationCompleted: StreamingListenerOutputOperationCompleted =>
      listener.onOutputOperationCompleted(outputOperationCompleted)
    case streamingStarted: StreamingListenerStreamingStarted =>
      listener.onStreamingStarted(streamingStarted)
    case _ =>
  }
}

override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
  val elements = batchCompleted.batchInfo.streamIdToInputInfo
  for {
    // pull the batch's basic metrics: processing end time, processing delay,
    // scheduling delay, and the record count for this stream ID
    processingEnd <- batchCompleted.batchInfo.processingEndTime
    workDelay <- batchCompleted.batchInfo.processingDelay
    waitDelay <- batchCompleted.batchInfo.schedulingDelay
    elems <- elements.get(streamUID).map(_.numRecords)
  } computeAndPublish(processingEnd, elems, workDelay, waitDelay)
}

// compute the new rate via rateEstimator.compute and publish it asynchronously
private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
  Future[Unit] {
    val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay)
    newRate.foreach { s =>
      rateLimit.set(s.toLong)
      publish(getLatestRate())
    }
  }
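
To see the whole listener pattern in one place, here is a minimal, self-contained sketch that mirrors the shape of RateController. It is not Spark's actual class: SimpleRateController, its toy rate estimate, and the streamId parameter are all illustrative.

import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

// Minimal sketch mirroring RateController: listen for batch completion,
// derive a rate asynchronously, and hand it to publish().
abstract class SimpleRateController(streamId: Int)(implicit ec: ExecutionContext)
  extends StreamingListener {

  private val rateLimit = new AtomicLong(-1L)

  // Subclasses decide how to deliver the new rate (Spark's ReceiverRateController
  // sends it to the ReceiverTracker over RPC, see section 3 below).
  protected def publish(rate: Long): Unit

  def getLatestRate(): Long = rateLimit.get()

  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
    val info = batchCompleted.batchInfo
    for {
      processingDelay <- info.processingDelay
      numRecords <- info.streamIdToInputInfo.get(streamId).map(_.numRecords)
      if processingDelay > 0 && numRecords > 0
    } Future {
      // Toy estimate (records/second of the last batch) instead of the PID estimator.
      val newRate = numRecords * 1000 / processingDelay
      rateLimit.set(newRate)
      publish(getLatestRate())
    }
  }
}

A concrete subclass only needs to implement publish, which is exactly the seam ReceiverRateController fills in below.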

2. PIDRateEstimator, an implementation of RateEstimator

PIDRateEstimator implements RateEstimator's compute method. The calculation is based on the classic PID controller from control engineering: a feedback control algorithm that combines the proportional, integral, and derivative terms of the error, and uses that error feedback to steer the target quantity (here, the ingestion rate) toward the desired value.

def compute(
    time: Long, // in milliseconds
    numElements: Long,
    processingDelay: Long, // in milliseconds
    schedulingDelay: Long // in milliseconds
  ): Option[Double] = {
  logTrace(s"\ntime = $time, # records = $numElements, " +
    s"processing time = $processingDelay, scheduling delay = $schedulingDelay")
  this.synchronized {
    if (time > latestTime && numElements > 0 && processingDelay > 0) {

      // in seconds, should be close to batchDuration
      val delaySinceUpdate = (time - latestTime).toDouble / 1000

      // in elements/second
      val processingRate = numElements.toDouble / processingDelay * 1000

      // In our system `error` is the difference between the desired rate and the measured rate
      // based on the latest batch information. We consider the desired rate to be latest rate,
      // which is what this estimator calculated for the previous batch.
      // in elements/second
      val error = latestRate - processingRate

      // The error integral, based on schedulingDelay as an indicator for accumulated errors.
      // A scheduling delay s corresponds to s * processingRate overflowing elements. Those
      // are elements that couldn't be processed in previous batches, leading to this delay.
      // In the following, we assume the processingRate didn't change too much.
      // From the number of overflowing elements we can calculate the rate at which they would be
      // processed by dividing it by the batch interval. This rate is our "historical" error,
      // or integral part, since if we subtracted this rate from the previous "calculated rate",
      // there wouldn't have been any overflowing elements, and the scheduling delay would have
      // been zero.
      // (in elements/second)
      val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis

      // in elements/(second ^ 2)
      val dError = (error - latestError) / delaySinceUpdate

      // The core formula: combine the three terms, and take the larger of the result and the
      // floor configured via spark.streaming.backpressure.pid.minRate as the next ingestion rate.
      val newRate = (latestRate - proportional * error -
        integral * historicalError -
        derivative * dError).max(minRate)
      logTrace(s"""
        | latestRate = $latestRate, error = $error
        | latestError = $latestError, historicalError = $historicalError
        | delaySinceUpdate = $delaySinceUpdate, dError = $dError
        """.stripMargin)

      latestTime = time
      if (firstRun) {
        latestRate = processingRate
        latestError = 0D
        firstRun = false
        logTrace("First run, rate estimation skipped")
        None
      } else {
        latestRate = newRate
        latestError = error
        logTrace(s"New rate = $newRate")
        Some(newRate)
      }
    } else {
      logTrace("Rate estimation skipped")
      None
    }
  }
}
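
To make the formula concrete, here is a small worked update. The gains use the documented defaults of the spark.streaming.backpressure.pid.* settings (proportional = 1.0, integral = 0.2, derivative = 0.0, minRate = 100); PidUpdateExample and all the batch measurements are invented purely for illustration.

// Worked update of the PID formula with default gains; numbers are made up.
object PidUpdateExample extends App {
  val proportional = 1.0
  val integral     = 0.2
  val derivative   = 0.0
  val minRate      = 100.0

  val batchIntervalMillis = 2000L  // 2-second batches
  val latestRate          = 1000.0 // rate chosen for the previous batch (elements/s)
  val latestError         = 0.0

  // measurements from the batch that just completed
  val numElements      = 1600L
  val processingDelay  = 2000L // ms spent processing the batch
  val schedulingDelay  = 1000L // ms the batch waited before starting
  val delaySinceUpdate = 2.0   // seconds since the previous estimate

  val processingRate  = numElements.toDouble / processingDelay * 1000                   // 800 elements/s
  val error           = latestRate - processingRate                                     // 200
  val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis // 400
  val dError          = (error - latestError) / delaySinceUpdate                        // 100

  val newRate = (latestRate - proportional * error -
    integral * historicalError -
    derivative * dError).max(minRate) // 1000 - 200 - 80 - 0 = 720

  println(s"new rate = $newRate elements/s") // 720.0
}

Because the integral gain is positive, a growing scheduling delay keeps pushing the rate down until the backlog drains; once delays disappear, the rate converges back toward the measured processing rate.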

3. PUSH: delivering the new rate to the Receiver

ReceiverRateController, defined inside ReceiverInputDStream, is a concrete subclass of RateController and implements the parent class's publish method:

private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator)
  extends RateController(id, estimator) {
  override def publish(rate: Long): Unit =
    ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
}
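
For context, ReceiverInputDStream only creates this controller when backpressure is enabled; the wiring looks roughly like the following (paraphrased from the Spark 2.x source, so treat it as a sketch rather than an exact quote):

// Inside ReceiverInputDStream: the controller exists only if
// spark.streaming.backpressure.enabled is true.
override protected[streaming] val rateController: Option[RateController] = {
  if (RateController.isBackPressureEnabled(ssc.conf)) {
    Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration)))
  } else {
    None
  }
}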

This is the actual send path. The endpoint here is the ReceiverTrackerEndpoint, so you can see that Spark's RPC machinery is used: the rate-control message is delivered to the Receiver through the ReceiverTrackerEndpoint registered in the RpcEnv.

def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized {
  if (isTrackerStarted) {
    endpoint.send(UpdateReceiverRateLimit(streamUID, newRate))
  }
}

On the driver, the ReceiverTrackerEndpoint matches this message and forwards the new rate to the RpcEndpoint registered by the Receiver, i.e. the endpoint responsible for receiving and handling messages sent from the driver-side ReceiverTracker:

case UpdateReceiverRateLimit(streamUID, newRate) =>
  for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) {
    eP.send(UpdateRateLimit(newRate))
  }

On the receiver side, the endpoint matches the incoming message and handles it in ReceiverSupervisorImpl:

case UpdateRateLimit(eps) =>
  logInfo(s"Received a new rate limit: $eps.")
  registeredBlockGenerators.asScala.foreach { bg =>
    bg.updateRate(eps)
  }
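
To summarize the hop, two messages are involved; their shapes are roughly as follows (paraphrased from the Spark source for orientation, without the access modifiers and parent traits, so verify the exact definitions in your version):

// driver side: RateController -> ReceiverTracker -> ReceiverTrackerEndpoint
case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long)

// driver -> executor: ReceiverTrackerEndpoint -> the receiver's ReceiverSupervisorImpl endpoint
case class UpdateRateLimit(elementsPerSecond: Long)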

4. RateLimiter, the class that actually enforces the rate

private[receiver] def updateRate(newRate: Long): Unit =
  if (newRate > 0) {
    if (maxRateLimit > 0) {
      rateLimiter.setRate(newRate.min(maxRateLimit))
    } else {
      rateLimiter.setRate(newRate)
    }
  }

As you can see, the actual limiting is delegated to Guava's RateLimiter:

// treated as an upper limit
private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue)
private lazy val rateLimiter = GuavaRateLimiter.create(getInitialRateLimit().toDouble)

def waitToPush() {
  // acquire a permit, blocking until one is available
  rateLimiter.acquire()
}
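
For intuition about what waitToPush relies on, here is a small standalone sketch of Guava's RateLimiter behaviour; RateLimiterDemo and the chosen rates are illustrative, not part of Spark.

import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter}

object RateLimiterDemo extends App {
  // 5 permits per second: acquire() blocks roughly 200 ms between permits.
  val limiter = GuavaRateLimiter.create(5.0)
  val start = System.nanoTime()
  (1 to 10).foreach { i =>
    limiter.acquire()
    println(f"permit $i granted at ${(System.nanoTime() - start) / 1e9}%.2f s")
  }
  // A later updateRate() call corresponds to this: the new rate takes effect immediately.
  limiter.setRate(50.0)
}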

The place that uses the RateLimiter is BlockGenerator:

/**
 * Push a single data item into the buffer.
 */
def addData(data: Any): Unit = {
  if (state == Active) {
    // this is where the rate limit bites: block until a permit is available
    waitToPush()
    synchronized {
      if (state == Active) {
        currentBuffer += data
      } else {
        throw new SparkException(
          "Cannot add data as BlockGenerator has not been started or has been stopped")
      }
    }
  } else {
    throw new SparkException(
      "Cannot add data as BlockGenerator has not been started or has been stopped")
  }
}
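
Finally, to exercise this whole path in an application, the relevant configuration keys look like the following. Key names are taken from the Spark 2.x configuration docs; the values chosen here are only examples.

import org.apache.spark.SparkConf

object BackpressureConf extends App {
  val conf = new SparkConf()
    .setAppName("backpressure-demo")
    // turn the RateController / PIDRateEstimator mechanism on (off by default)
    .set("spark.streaming.backpressure.enabled", "true")
    // rate used for the first batch, before any PID estimate exists
    .set("spark.streaming.backpressure.initialRate", "1000")
    // hard upper bound, i.e. maxRateLimit in updateRate() above
    .set("spark.streaming.receiver.maxRate", "5000")
    // lower bound used by PIDRateEstimator, i.e. minRate in compute()
    .set("spark.streaming.backpressure.pid.minRate", "100")
  println(conf.toDebugString)
}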