Created
July 15, 2017 11:06
-
-
Save jroper/4f8108f8fa00a7251919e08d4cf9eb71 to your computer and use it in GitHub Desktop.
Akka streams LazyBroadcastHub - a broadcast hub that only keeps its source materialized as long as there are consumers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package streams.utils | |
import akka.NotUsed | |
import akka.stream._ | |
import akka.stream.scaladsl.{BroadcastHub, Keep, RunnableGraph, Source} | |
import akka.stream.stage._ | |
import scala.concurrent.duration.{Duration, FiniteDuration} | |
/**
 * Provides a broadcast hub that only runs the source when there are sinks connected to it.
 */
object LazyBroadcastHub {

  /** Buffer size used by the overloads that do not take an explicit `bufferSize`. */
  private val DefaultBufferSize = 256

  /**
   * Create a broadcast hub for the given source.
   *
   * The hub will only run the source when there are consumers attached to the hub. When all consumers disconnect,
   * after the given idle timeout, if no more consumers connect it will shut the source down. An idle timeout of
   * zero shuts the source down as soon as the last consumer disconnects.
   *
   * The source will be rematerialized whenever it's not running but a new consumer attaches to the hub.
   *
   * The materialization value is a tuple of a source as produced by BroadcastHub, and a KillSwitch to kill the hub.
   *
   * @param source The source to broadcast.
   * @param idleTimeout The time to wait when there are no consumers before shutting the source down.
   * @param bufferSize The buffer size passed to the underlying BroadcastHub for buffering messages to consumers.
   */
  def forSource[T](source: Source[T, _], idleTimeout: FiniteDuration, bufferSize: Int): RunnableGraph[(Source[T, NotUsed], KillSwitch)] = {
    Source.fromGraph(new LazySourceStage[T](source, idleTimeout))
      .viaMat(KillSwitches.single)(Keep.both)
      .toMat(BroadcastHub.sink[T](bufferSize)) {
        case ((callbacks, killSwitch), broadcastSource) =>
          // Every materialization of the returned source reports its start and stop to the lazy source stage.
          val recordedSource = broadcastSource.via(new RecordingStage[T](callbacks))
          (recordedSource, killSwitch)
      }
  }

  /** As the three-argument `forSource`, with a buffer size of 256. */
  def forSource[T](source: Source[T, _], idleTimeout: FiniteDuration): RunnableGraph[(Source[T, NotUsed], KillSwitch)] =
    forSource(source, idleTimeout, bufferSize = DefaultBufferSize)

  /** As the three-argument `forSource`, with no idle timeout: the source stops when the last consumer leaves. */
  def forSource[T](source: Source[T, _], bufferSize: Int): RunnableGraph[(Source[T, NotUsed], KillSwitch)] =
    forSource(source, Duration.Zero, bufferSize)

  /** As the three-argument `forSource`, with no idle timeout and a buffer size of 256. */
  def forSource[T](source: Source[T, _]): RunnableGraph[(Source[T, NotUsed], KillSwitch)] =
    forSource(source, Duration.Zero)

  /**
   * Notifications sent from each consumer-side [[RecordingStage]] to the [[LazySourceStage]].
   */
  private trait MaterializationCallbacks {
    /** A consumer has started. */
    def materialized(): Unit
    /** A consumer has stopped. */
    def completed(): Unit
  }

  /**
   * Pass-through flow attached to every consumer of the hub. It forwards elements unchanged and reports the
   * start and the stop of the consumer through the given callbacks.
   */
  private class RecordingStage[T](callbacks: MaterializationCallbacks) extends GraphStage[FlowShape[T, T]] {
    private val in = Inlet[T]("RecordingStage.in")
    private val out = Outlet[T]("RecordingStage.out")

    override def shape: FlowShape[T, T] = FlowShape(in, out)

    override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
      setHandler(in, new InHandler {
        override def onPush(): Unit = push(out, grab(in))
      })

      // onDownstreamFinish is deliberately NOT overridden: the default handler completes the stage, which
      // cancels the inlet and so detaches this consumer from the broadcast hub. Overriding it without
      // completing the stage would leave a consumer registered with the hub that never pulls again.
      setHandler(out, new OutHandler {
        override def onPull(): Unit = pull(in)
      })

      override def preStart(): Unit = {
        // This must be done in preStart, if done during materialization then there's a race for the LazySourceStage
        // to finish materializing before this gets invoked.
        callbacks.materialized()
      }

      override def postStop(): Unit = {
        // Reported from postStop so that every termination path (downstream cancel, upstream finish or
        // failure) is paired with exactly one materialized() notification from preStart.
        callbacks.completed()
      }
    }
  }

  /**
   * Source stage that runs the wrapped source only while at least one consumer is registered through the
   * [[MaterializationCallbacks]] it materializes.
   */
  private class LazySourceStage[T](source: Source[T, _], idleTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[T], MaterializationCallbacks] {
    private val out = Outlet[T]("LazySourceStage.out")

    override def shape: SourceShape[T] = SourceShape(out)

    override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, MaterializationCallbacks) = {
      val logic = new TimerGraphStageLogic(shape) with MaterializationCallbacks {
        // Number of consumers currently running.
        var materializedSources = 0
        // The sub-sink the wrapped source is currently running into, if it is running.
        var activeIn: Option[SubSinkInlet[T]] = None
        // Incremented for every scheduled idle stop and used as the timer key, so that only the most
        // recently scheduled timer is allowed to stop the source.
        var stopSourceRequest = 0

        val materializedCallback = createAsyncCallback[Unit] { _ =>
          materializedSources += 1
          if (activeIn.isEmpty) {
            startSource()
          }
        }

        val completedCallback = createAsyncCallback[Unit] { _ =>
          materializedSources -= 1
          if (materializedSources == 0) {
            if (idleTimeout == Duration.Zero) {
              stopSource()
            } else {
              stopSourceRequest += 1
              scheduleOnce(stopSourceRequest, idleTimeout)
            }
          }
        }

        /** Materializes the wrapped source into a fresh sub-sink and wires it to `out`. */
        def startSource(): Unit = {
          assert(activeIn.isEmpty)
          val in = new SubSinkInlet[T]("LazySourceStage.in")
          in.setHandler(new InHandler {
            override def onPush(): Unit = push(out, in.grab())
          })
          setHandler(out, new OutHandler {
            override def onPull(): Unit = in.pull()
          })
          source.runWith(in.sink)(subFusingMaterializer)
          // Demand may have arrived while no source was running; forward it now.
          if (isAvailable(out)) {
            in.pull()
          }
          activeIn = Some(in)
        }

        /** Cancels the running source, if any, and parks `out` until the next consumer arrives. */
        def stopSource(): Unit = {
          activeIn.foreach(_.cancel())
          ignoreOut()
          activeIn = None
        }

        override protected def onTimer(timerKey: Any): Unit = {
          // Ignore stale timers and timers that fired after a new consumer has attached.
          if (stopSourceRequest == timerKey && materializedSources == 0) {
            stopSource()
          }
        }

        /** Installs a handler that swallows demand while the source is not running. */
        def ignoreOut(): Unit = {
          setHandler(out, new OutHandler {
            override def onPull(): Unit = ()
          })
        }

        override def materialized(): Unit = {
          materializedCallback.invoke(())
        }

        override def completed(): Unit = {
          completedCallback.invoke(())
        }

        override def postStop(): Unit = {
          // If the hub is shut down (e.g. through the KillSwitch) while the source is still running, make
          // sure the wrapped source is cancelled instead of being left running with nobody consuming it.
          activeIn.filterNot(_.isClosed).foreach(_.cancel())
          activeIn = None
        }

        // Until the first consumer attaches there is no source to pull from.
        ignoreOut()
      }
      (logic, logic)
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package streams.utils | |
import java.util.concurrent.atomic.AtomicBoolean | |
import akka.Done | |
import akka.actor.ActorSystem | |
import akka.stream.scaladsl.{Sink, Source} | |
import akka.stream.testkit.javadsl.TestSink | |
import akka.stream.{ActorMaterializer, Materializer} | |
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec} | |
import play.api.{Environment, LoggerConfigurator} | |
import streams.utils.LazyBroadcastHub | |
import scala.concurrent.{Await, Promise} | |
import scala.concurrent.duration._ | |
/**
 * Behavioural tests for [[LazyBroadcastHub]]: the wrapped source must only run while consumers are attached,
 * optionally lingering for an idle timeout after the last consumer disconnects.
 */
class LazyBroadcastHubSpec extends WordSpec with Matchers with BeforeAndAfterAll {

  // Created in beforeAll and torn down in afterAll.
  implicit var system: ActorSystem = _
  implicit var materializer: Materializer = _
  implicit def dispatcher = system.dispatcher

  val environment = Environment.simple()
  LoggerConfigurator(environment.classLoader).foreach(_.configure(environment))

  /**
   * An endless source of "a" that completes the given promise when the source terminates, so that tests can
   * observe when the hub shuts the source down.
   */
  private def watchedSource(shutdown: Promise[Done]): Source[String, _] =
    Source.repeat("a").watchTermination() { (_, term) =>
      shutdown.completeWith(term)
    }

  "LazyBroadcastHub" should {

    "not start the source if there are no consumers" in {
      val materialized = new AtomicBoolean()
      LazyBroadcastHub.forSource(Source.empty.mapMaterializedValue(_ => materialized.set(true))).run()
      Thread.sleep(200)
      materialized.get() should be (false)
    }

    "start the source when a consumer attaches" in {
      val (source, _) = LazyBroadcastHub.forSource(Source.repeat("a")).run()
      val sink = source.runWith(TestSink.probe(system))
      sink.requestNext("a")
    }

    "shut down the source when a single consumer disconnects" in {
      val shutdown = Promise[Done]()
      val (source, _) = LazyBroadcastHub.forSource(watchedSource(shutdown)).run()
      source.runWith(Sink.head)
      Await.ready(shutdown.future, 10.seconds)
    }

    "not shutdown when there is still a consumer" in {
      val shutdown = Promise[Done]()
      val (source, _) = LazyBroadcastHub.forSource(watchedSource(shutdown)).run()
      val sink1 = source.runWith(TestSink.probe(system))
      val sink2 = source.runWith(TestSink.probe(system))
      sink1.requestNext("a")
      sink2.requestNext("a")
      sink2.cancel()
      Thread.sleep(200)
      shutdown.isCompleted should be (false)
    }

    "shut down when multiple consumers disconnect" in {
      val shutdown = Promise[Done]()
      val (source, _) = LazyBroadcastHub.forSource(watchedSource(shutdown)).run()
      val sink1 = source.runWith(TestSink.probe(system))
      val sink2 = source.runWith(TestSink.probe(system))
      sink1.requestNext("a")
      sink2.requestNext("a")
      sink1.cancel()
      sink2.cancel()
      Await.ready(shutdown.future, 10.seconds)
    }

    "wait until a timeout before disconnecting" in {
      val shutdown = Promise[Done]()
      val (source, _) = LazyBroadcastHub.forSource(watchedSource(shutdown), 300.millis).run()
      source.runWith(Sink.head)
      // Still inside the 300ms idle timeout: the source must be alive.
      Thread.sleep(200)
      shutdown.isCompleted should be (false)
      Await.ready(shutdown.future, 10.seconds)
    }

    "not disconnect if a new sink connects within the timeout" in {
      val shutdown = Promise[Done]()
      val (source, _) = LazyBroadcastHub.forSource(watchedSource(shutdown), 300.millis).run()
      source.runWith(Sink.head)
      Thread.sleep(200)
      val sink = source.runWith(TestSink.probe(system))
      sink.requestNext("a")
      // The original idle timer has elapsed by now, but a consumer is attached again.
      Thread.sleep(200)
      shutdown.isCompleted should be (false)
    }
  }

  override protected def beforeAll(): Unit = {
    system = ActorSystem("Test")
    materializer = ActorMaterializer()
  }

  override protected def afterAll(): Unit = {
    // Wait for the actor system to actually stop so it does not outlive the suite.
    Await.result(system.terminate(), 10.seconds)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment