Last active
March 12, 2018 14:22
-
-
Save jroper/f28ecf79f4a4be70e3f499a672d8d6b5 to your computer and use it in GitHub Desktop.
Akka streams Source.restartWithBackoff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package streams.utils | |
import java.util.concurrent.ThreadLocalRandom | |
import akka.NotUsed | |
import akka.stream._ | |
import akka.stream.scaladsl.Source | |
import akka.stream.stage._ | |
import scala.concurrent.duration._ | |
object SourceWithBackoffSupervision { | |
/**
 * Enrichment that adds `restartWithBackoff` to any `Source`.
 *
 * @param source The source to be wrapped; its materialized value is not exposed.
 */
implicit class EnrichedSupervisableSource[T](source: Source[T, _]) {

  /**
   * Wraps this source in a stage that runs it again, after a delay, every time it
   * completes or fails. Because completion also triggers a restart, the wrapped
   * source keeps restarting instead of completing.
   *
   * The delay starts at `minBackoff`, doubles with each consecutive restart and is
   * capped at `maxBackoff` before jitter is applied.
   *
   * @param minBackoff   Delay used for the first restart.
   * @param maxBackoff   Upper bound for the exponentially growing delay (before jitter).
   * @param randomFactor Jitter: the delay is multiplied by `1 + r * randomFactor`, where
   *                     `r` is a random number in `[0, 1)`. Pass `0` for no jitter.
   * @return A source emitting the elements of every run of the wrapped source.
   */
  def restartWithBackoff(
    minBackoff: FiniteDuration,
    maxBackoff: FiniteDuration,
    randomFactor: Double
  ): Source[T, NotUsed] = {
    val restartingStage = new RestartWithBackoff[T](minBackoff, maxBackoff, randomFactor, source)
    Source.fromGraph(restartingStage)
  }
}
/**
 * Source stage that runs `thisSource` as a sub-stream, forwards its elements to `out`
 * and, whenever that sub-stream completes or fails, schedules a fresh run after a
 * delay obtained from `calculateDelay`.
 *
 * @param minBackoff   Delay for the first restart; also the time a run has to last
 *                     (measured via `resetDeadline`) for the restart counter to reset.
 * @param maxBackoff   Cap passed on to `calculateDelay`.
 * @param randomFactor Jitter factor passed on to `calculateDelay`.
 * @param thisSource   The graph that is (re-)materialized on every start.
 */
private final class RestartWithBackoff[T](
  minBackoff: FiniteDuration,
  maxBackoff: FiniteDuration,
  randomFactor: Double,
  thisSource: Graph[SourceShape[T], _]
) extends GraphStage[SourceShape[T]] {

  private val out = Outlet[T]("RestartWithBackoff.out")

  override def shape: SourceShape[T] = SourceShape(out)

  override def initialAttributes: Attributes = Attributes.name("RestartWithBackoff")

  override def createLogic(attr: Attributes): GraphStageLogic =
    new TimerGraphStageLogicWithLogging(shape) {

      // Number of restarts since the last run that outlived its reset deadline.
      private var restartCount = 0

      // A run that ends after this deadline resets restartCount. It is initialised when
      // the logic is created and renewed each time the restart timer fires.
      private var resetDeadline = minBackoff.fromNow

      // Installed on `out` while waiting for the restart timer: pulls are not acted upon
      // here; startSource() checks isAvailable(out) to pick pending demand up later.
      private val ignoreDemand = new OutHandler {
        override def onPull(): Unit = ()
      }

      /** Materializes a new run of `thisSource` and wires it to `out`. */
      private def startSource(): Unit = {
        val subInlet = new SubSinkInlet[T]("RestartWithBackoffSink")

        val fromSubStream = new InHandler {
          override def onPush(): Unit = push(out, subInlet.grab())

          override def onUpstreamFinish(): Unit = {
            log.debug("Source finished")
            onCompleteOrFailure()
          }

          override def onUpstreamFailure(ex: Throwable): Unit = {
            log.error(ex, "Restarting source due to failure")
            onCompleteOrFailure()
          }
        }

        val fromDownstream = new OutHandler {
          override def onPull(): Unit = subInlet.pull()

          // Propagate downstream cancellation to the running sub-stream.
          override def onDownstreamFinish(): Unit = subInlet.cancel()
        }

        subInlet.setHandler(fromSubStream)
        setHandler(out, fromDownstream)

        Source.fromGraph(thisSource).runWith(subInlet.sink)(subFusingMaterializer)

        // Forward demand that was signalled before this run existed.
        if (isAvailable(out)) subInlet.pull()
      }

      /** Schedules the next start and parks `out` until the timer fires. */
      private def onCompleteOrFailure(): Unit = {
        if (resetDeadline.isOverdue()) restartCount = 0

        val delay = calculateDelay(restartCount, minBackoff, maxBackoff, randomFactor)
        log.debug("Restarting stream in {}", delay)
        scheduleOnce("RestartTimer", delay)
        restartCount += 1

        setHandler(out, ignoreDemand)
      }

      override protected def onTimer(timerKey: Any): Unit = {
        startSource()
        resetDeadline = minBackoff.fromNow
      }

      // Initial state: the first run is started by the first pull from downstream.
      setHandler(out, new OutHandler {
        override def onPull(): Unit = startSource()
      })
    }

  override def toString: String = "RestartWithBackoff"
}
/**
 * Computes the delay before the next restart: `minBackoff * 2^restartCount`, capped at
 * `maxBackoff`, then multiplied by a jitter of `1 + r * randomFactor` with random `r`.
 *
 * Copied from akka.pattern.BackoffSupervisor.
 *
 * @param restartCount Number of restarts so far; from 30 upwards `maxBackoff` is
 *                     returned as is, without jitter.
 * @param minBackoff   Base delay.
 * @param maxBackoff   Cap applied before jitter; also the fallback when the result is
 *                     not a finite duration.
 * @param randomFactor Scales the random part of the jitter multiplier.
 * @return The delay to wait before restarting.
 */
private def calculateDelay(
  restartCount: Int,
  minBackoff: FiniteDuration,
  maxBackoff: FiniteDuration,
  randomFactor: Double): FiniteDuration = {
  val jitter = 1.0 + ThreadLocalRandom.current().nextDouble() * randomFactor
  if (restartCount < 30) {
    val exponential = minBackoff * math.pow(2, restartCount)
    val capped = maxBackoff.min(exponential)
    capped * jitter match {
      case finite: FiniteDuration => finite
      case _ => maxBackoff
    }
  } else {
    // Duration overflow protection (> 100 years)
    maxBackoff
  }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package stream.utils | |
import akka.actor.ActorSystem | |
import akka.stream.{ActorMaterializer, Materializer} | |
import akka.stream.scaladsl.Source | |
import akka.stream.testkit.scaladsl.TestSink | |
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec} | |
import scala.concurrent.duration._ | |
/**
 * Tests for the `restartWithBackoff` enrichment: pass-through of elements, restart on
 * completion, restart on failure, and the delay before a restart.
 *
 * One `ActorSystem` is shared by all tests; it is created in `beforeAll` and terminated
 * in `afterAll`.
 */
class SourceWithBackoffSupervisionSpec extends WordSpec with Matchers with BeforeAndAfterAll {

  import scala.concurrent.Await
  import streams.utils.SourceWithBackoffSupervision._

  // Assigned in beforeAll; implicit so that runWith and TestSink.probe can pick them up.
  implicit var system: ActorSystem = _
  implicit var materializer: Materializer = _

  "RestartWithBackoff source" should {

    "run normally" in {
      val probe = Source.repeat("a")
        .restartWithBackoff(500.millis, 1.seconds, 0)
        .runWith(TestSink.probe)

      probe.requestNext("a")
      probe.requestNext("a")
      probe.requestNext("a")
      probe.requestNext("a")
      probe.requestNext("a")

      probe.cancel()
    }

    "restart on completion" in {
      val probe = Source(List("a", "b"))
        .restartWithBackoff(10.millis, 100.millis, 0)
        .runWith(TestSink.probe)

      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")

      probe.cancel()
    }

    "restart on failure" in {
      val probe = Source(List("a", "b", "c"))
        .map {
          case "c" => sys.error("failed")
          case other => other
        }
        .restartWithBackoff(10.millis, 100.millis, 0)
        .runWith(TestSink.probe)

      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")
      probe.requestNext("a")
      probe.requestNext("b")

      probe.cancel()
    }

    "backoff before restart" in {
      val probe = Source(List("a", "b"))
        .restartWithBackoff(1.second, 2.seconds, 0)
        .runWith(TestSink.probe)

      probe.requestNext("a")
      probe.requestNext("b")

      // With no jitter the first restart waits minBackoff (1 second): nothing may arrive
      // in the first 500 ms, and the element must arrive within the following second.
      probe.request(1)
      probe.expectNoMsg(500.milliseconds)
      probe.expectNext(1.second, "a")
      probe.requestNext("b")

      probe.cancel()
    }
  }

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    system = ActorSystem("Test")
    materializer = ActorMaterializer()
  }

  override protected def afterAll(): Unit = {
    // Terminate the actor system so its threads do not outlive the suite. The Option
    // guards against beforeAll having failed before `system` was assigned.
    try Option(system).foreach(running => Await.result(running.terminate(), 10.seconds))
    finally super.afterAll()
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment