Created
June 21, 2015 19:47
-
-
Save searler/124ba55566e4cbc3c0cd to your computer and use it in GitHub Desktop.
Akka Streams BidiFlow example from the documentation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Extracted from the BidiFlow example at http://doc.akka.io/docs/akka-stream-and-http-experimental/1.0-RC3/scala/stream-graphs.html
 */
package bidi | |
import akka.util.ByteString | |
import akka.stream.scaladsl.BidiFlow | |
import akka.stream.scaladsl.Flow | |
import akka.stream.BidiShape | |
import java.nio.ByteOrder | |
import akka.stream.stage.Context | |
import akka.stream.stage.PushPullStage | |
import akka.stream.stage.SyncDirective | |
import akka.stream.scaladsl.Source | |
import akka.stream.scaladsl.Sink | |
import scala.concurrent.Await | |
import akka.stream.ActorFlowMaterializer | |
import akka.actor.ActorSystem | |
/**
 * Demonstrates stacking two bidirectional protocol layers: a message codec
 * (Message <-> ByteString) on top of length-prefixed framing
 * (ByteString <-> ByteString). The stack is then verified by joining it with
 * its own inverse and a ping/pong responder.
 */
object ProtocolStacker extends App {

  /** Messages handled by the codec layer; sealed so `toBytes` is checked for exhaustiveness. */
  sealed trait Message
  final case class Ping(id: Int) extends Message
  final case class Pong(id: Int) extends Message

  /**
   * Encodes a message as a one-byte tag (1 = Ping, 2 = Pong) followed by the
   * id as a little-endian 32-bit integer.
   */
  def toBytes(msg: Message): ByteString = {
    implicit val order: ByteOrder = ByteOrder.LITTLE_ENDIAN
    msg match {
      case Ping(id) => ByteString.newBuilder.putByte(1).putInt(id).result()
      case Pong(id) => ByteString.newBuilder.putByte(2).putInt(id).result()
    }
  }

  /**
   * Decodes the format produced by `toBytes`. Any bytes after the first five are ignored.
   *
   * @throws RuntimeException if fewer than five bytes are supplied or the tag byte is not 1 or 2
   */
  def fromBytes(bytes: ByteString): Message = {
    implicit val order: ByteOrder = ByteOrder.LITTLE_ENDIAN
    // Report truncated input as a parse error rather than letting the
    // iterator fail with an opaque NoSuchElementException.
    if (bytes.length < 5)
      throw new RuntimeException(s"parse error: expected at least 5 bytes got ${bytes.length}")
    val it = bytes.iterator
    it.getByte match {
      case 1     => Ping(it.getInt)
      case 2     => Pong(it.getInt)
      case other => throw new RuntimeException(s"parse error: expected 1|2 got $other")
    }
  }

  /** Codec layer built explicitly from two flows inside a graph builder. */
  val codecVerbose = BidiFlow() { b =>
    // construct and add the top flow, going outbound
    val outbound = b.add(Flow[Message].map(toBytes))
    // construct and add the bottom flow, going inbound
    val inbound = b.add(Flow[ByteString].map(fromBytes))
    // fuse them together into a BidiShape
    BidiShape(outbound, inbound)
  }

  /** The same codec layer as `codecVerbose`, built from the two functions directly. */
  val codec = BidiFlow(toBytes _, fromBytes _)

  /**
   * Framing layer. Outbound, each chunk gets a 4-byte little-endian length prefix.
   * Inbound, the byte stream is re-split into the original chunks.
   */
  val framing = BidiFlow() { b =>
    implicit val order: ByteOrder = ByteOrder.LITTLE_ENDIAN

    def addLengthHeader(bytes: ByteString) = {
      val len = bytes.length
      ByteString.newBuilder.putInt(len).append(bytes).result()
    }

    class FrameParser extends PushPullStage[ByteString, ByteString] {
      // this holds the received but not yet parsed bytes
      var stash = ByteString.empty
      // this holds the current message length or -1 if at a boundary
      var needed = -1

      override def onPush(bytes: ByteString, ctx: Context[ByteString]) = {
        stash ++= bytes
        run(ctx)
      }

      override def onPull(ctx: Context[ByteString]) = run(ctx)

      override def onUpstreamFinish(ctx: Context[ByteString]) =
        if (stash.isEmpty) ctx.finish()
        else ctx.absorbTermination() // we still have bytes to emit

      private def run(ctx: Context[ByteString]): SyncDirective =
        if (needed == -1) {
          // are we at a boundary? then figure out next length
          if (stash.length < 4) pullOrFinish(ctx)
          else {
            val length = stash.iterator.getInt
            stash = stash.drop(4)
            // A negative length can only come from corrupt input. Left unchecked,
            // -1 would be mistaken for "at a boundary" and other negatives would
            // emit empty frames, so the stream would silently desynchronise.
            if (length < 0)
              ctx.fail(new RuntimeException(s"framing error: negative frame length $length"))
            else {
              needed = length
              run(ctx) // cycle back to possibly already emit the next chunk
            }
          }
        } else if (stash.length < needed) {
          // we are in the middle of a message, need more bytes
          pullOrFinish(ctx)
        } else {
          // we have enough to emit at least one message, so do it
          val emit = stash.take(needed)
          stash = stash.drop(needed)
          needed = -1
          ctx.push(emit)
        }

      /*
       * After having called absorbTermination() we cannot pull any more, so if we need
       * more data we will just have to give up.
       */
      private def pullOrFinish(ctx: Context[ByteString]) =
        if (ctx.isFinishing) ctx.finish()
        else ctx.pull()
    }

    val outbound = b.add(Flow[ByteString].map(addLengthHeader))
    val inbound = b.add(Flow[ByteString].transform(() => new FrameParser))
    BidiShape(outbound, inbound)
  }

  //---------------------------------
  implicit val system: ActorSystem = ActorSystem()
  implicit val materializer: ActorFlowMaterializer = ActorFlowMaterializer()
  import scala.concurrent.duration._

  val stack = codec.atop(framing)
  // test it by plugging it into its own inverse and closing the right end
  val pingpong = Flow[Message].collect { case Ping(id) => Pong(id) }
  val flow = stack.atop(stack.reversed).join(pingpong)
  val result = Source((0 to 9).map(Ping)).via(flow).grouped(20).runWith(Sink.head)

  // Shut the actor system down even if the await fails or times out. Otherwise
  // its non-daemon threads keep the JVM alive and the program never exits.
  try println(Await.result(result, 1.second)) // should ===((0 to 9).map(Pong))
  finally system.shutdown()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment