@mariusae
Created January 28, 2015 20:15
A demonstration of using Stack to implement per-endpoint behavior.
/*
* We're going to demonstrate how to use Finagle's Stack facilities to inject behavior
* deep down in Finagle's client stack. Namely, we want to capture the address of the
* remote server that handled a request and include it in that server's response.
*
* While retrieving the address of the server behind a response is something that
* perhaps deserves a first-class API, it's nevertheless interesting to exercise the
* flexibility of the Stack mechanism.
*
* This code uses only public APIs.
*
* Our strategy is to create a modified ThriftMux client that embeds logic to record
* the address of the endpoint it talks to and then modifies the response to include
* that address.
*
* We're going to use the following Thrift IDL, a simple echo service.
*
* > namespace java org.monkey.recordaddr
* >
* > service RecordingService {
* > string echo(1: string arg);
* > }
*
*/
package org.monkey.recordaddr
import com.twitter.finagle._
import com.twitter.finagle.client.{StackClient, Transporter}
import com.twitter.io.Buf
import com.twitter.scrooge.TReusableMemoryTransport
import com.twitter.scrooge.ThriftStruct
import com.twitter.scrooge.ThriftStructCodec
import com.twitter.util.{Await, Future, NonFatal}
import java.net.SocketAddress
import java.util.Arrays
import org.apache.thrift.TApplicationException
import org.apache.thrift.protocol.{TBinaryProtocol, TMessage, TMessageType}
import org.apache.thrift.transport.TMemoryInputTransport
/*
* Object ThriftMuxRecorded is the home of our modified client.
*/
object ThriftMuxRecorded {
/*
* First, we need some utilities for mucking around with Thrift structures
* as generated by Scrooge. Scrooge generates a _codec_ for each of its
* types: a ThriftStructCodec[T] can create a T-typed object from a Thrift
* Protocol. Thrift Protocols embody a particular Thrift representation,
* like "binary", "compact", or "JSON". A ThriftStructCodec can use a low
* level interface to the representation to create a first-class instance of
* a structure.
*/
private def decodeResponse[T <: ThriftStruct](resBytes: Array[Byte], codec: ThriftStructCodec[T]): Option[T] = {
/*
* Our method takes a byte array and a codec. We assume we're using the
* binary protocol (the default in Finagle), so we instantiate a TBinaryProtocol
* over a transport that simply wraps the byte array.
*/
val iprot = new TBinaryProtocol(new TMemoryInputTransport(resBytes))
/*
* Thrift requests and responses have a header that contains the method name,
* the message type (e.g. a normal REPLY vs. an EXCEPTION) and a sequence ID. We
* decode this header to make sure the message is a successful response.
*
* We're going to pass exceptions through (by returning None) and otherwise
* attempt to decode the message. We take a conservative approach: if we fail
* to decode the message, we simply skip it (again by returning None).
*/
val msg = iprot.readMessageBegin()
msg.`type` match {
case TMessageType.EXCEPTION => None
case _ =>
try Some(codec.decode(iprot)) catch {
case NonFatal(_) => None
}
}
}
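/*
* A Finagle-free sketch of the policy decodeResponse implements: Thrift EXCEPTION
* messages pass through as None; otherwise we try to decode, and any non-fatal
* decoding failure also becomes None. The message-type constants mirror Thrift's
* TMessageType values (REPLY = 2, EXCEPTION = 3); the `decode` thunk is a
* hypothetical stand-in for codec.decode(iprot). This is an illustration under
* those assumptions, not part of the client. For example,
* decodeOrSkip(ReplyType, () => "ok") returns Some("ok").
*/
object DecodePolicySketch {
  // Values mirror TMessageType.REPLY and TMessageType.EXCEPTION.
  val ReplyType: Byte = 2
  val ExceptionType: Byte = 3

  // Hypothetical helper: `decode` stands in for codec.decode(iprot).
  def decodeOrSkip[T](msgType: Byte, decode: () => T): Option[T] =
    if (msgType == ExceptionType) None
    else {
      try Some(decode())
      catch { case scala.util.control.NonFatal(_) => None }
    }
}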
/*
* After we've modified the response, we'll need to encode it back to
* a byte array again. Here's a method for that.
*/
private def encodeResponse(name: String, seqid: Int, result: ThriftStruct): Array[Byte] = {
/*
* Since we want our output as a fully buffered byte array, our "transport" is an in-memory one.
*/
val memoryBuffer: TReusableMemoryTransport = TReusableMemoryTransport(512)
try {
/*
* As in decodeResponse, we instantiate a binary protocol and write out our
* message header: a standard, successful REPLY message with the given
* name (this is the method name) and sequence ID.
*/
val oprot = new TBinaryProtocol(memoryBuffer)
oprot.writeMessageBegin(new TMessage(name, TMessageType.REPLY, seqid))
result.write(oprot)
oprot.writeMessageEnd()
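// getArray() exposes the transport's backing array, which may be larger than
// what we wrote; copy out exactly the bytes written.
Arrays.copyOfRange(memoryBuffer.getArray(), 0, memoryBuffer.length())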
}
finally {
memoryBuffer.reset()
}
}
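/*
* To make the header concrete, here is a standalone sketch of what
* writeMessageBegin produces for TBinaryProtocol in its default strict-write
* mode, assuming the standard Thrift binary encoding: a big-endian i32 holding
* the version (0x80010000) OR'd with the message type, the method name as an
* i32 length followed by UTF-8 bytes, and the i32 sequence ID. The object and
* method names are hypothetical, and it uses only the JDK, not libthrift. For
* example, encodeHeader("echo", 2, 7) yields 16 bytes starting 0x80 0x01 0x00 0x02.
*/
object BinaryHeaderWriteSketch {
  // 0x80010000, built with a shift so it stays an Int.
  val Version1: Int = 0x8001 << 16

  // Hypothetical helper mirroring the header part of writeMessageBegin.
  def encodeHeader(name: String, msgType: Byte, seqid: Int): Array[Byte] = {
    val bytes = new java.io.ByteArrayOutputStream()
    val out = new java.io.DataOutputStream(bytes) // big-endian, like the binary protocol
    val nameBytes = name.getBytes(java.nio.charset.StandardCharsets.UTF_8)
    out.writeInt(Version1 | msgType) // version and message type share one i32
    out.writeInt(nameBytes.length)   // method name: length prefix...
    out.write(nameBytes)             // ...then the UTF-8 bytes
    out.writeInt(seqid)
    out.flush()
    bytes.toByteArray
  }
}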
/*
* Finally, a utility to extract the method name from a serialized Thrift message.
* Our Thrift service might have multiple methods, and we only want to apply our
* logic to a subset of them.
*/
private def methodName(bytes: Array[Byte]): String = {
val trans = new TMemoryInputTransport(bytes)
val prot = new TBinaryProtocol(trans)
prot.readMessageBegin().name
}
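/*
* The reverse direction, again as a JDK-only sketch under the same assumption
* about the strict binary layout: roughly what readMessageBegin parses before
* methodName returns the name. The names are hypothetical, and unlike
* TBinaryProtocol (which also accepts the older non-strict header when
* strictRead is off) this sketch rejects anything but the strict format.
*/
object BinaryHeaderReadSketch {
  val Version1: Int = 0x8001 << 16    // 0x80010000
  val VersionMask: Int = 0xffff << 16 // 0xffff0000

  // Hypothetical stand-in for TMessage.
  final case class Header(name: String, msgType: Byte, seqid: Int)

  def decodeHeader(bytes: Array[Byte]): Header = {
    val buf = java.nio.ByteBuffer.wrap(bytes) // big-endian by default
    val versionAndType = buf.getInt()
    require((versionAndType & VersionMask) == Version1, "not a strict binary-protocol header")
    val nameBytes = new Array[Byte](buf.getInt())
    buf.get(nameBytes)
    val name = new String(nameBytes, java.nio.charset.StandardCharsets.UTF_8)
    // The low byte of the first i32 is the message type; the sequence ID follows the name.
    Header(name, (versionAndType & 0xff).toByte, buf.getInt())
  }
}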
/*
* Now we're getting to the meat of our little utility. Recorder is a simple
* filter that is parameterized on a SocketAddress (that of the server endpoint
* our client is connected to) and embeds this address in replies to the "echo" method.
*/
class Recorder(addr: SocketAddress) extends SimpleFilter[mux.Request, mux.Response] {
/*
* We're going to pass the mux request through, and examine only the response.
* A Mux response is a single-method trait:
* > trait Response { def body: Buf }
* The body is the payload of the response. Buf is Twitter util's abstraction for doing
* efficient byte buffering.
*/
def apply(req: mux.Request, service: Service[mux.Request, mux.Response]) = {
service(req) map { res =>
/*
* This is a bit of a mouthful: Buf.ByteArray.Owned.extract(res.body) says:
* coerce res.body (a Buf) into an "Owned" byte array. This means that the
* caller is responsible for not modifying the byte array, as it may be shared
* directly with other users. Provided you don't modify it, this is generally
* more efficient, since we don't need to copy the payload for each use.
*/
methodName(Buf.ByteArray.Owned.extract(res.body)) match {
case "echo" =>
/*
* This invocation of decodeResponse tries to decode a RecordingService.echo$result out
* of the response payload. Thrift RPC requests and responses are standard Thrift structures:
* for each method, Scrooge generates structures of the form {service}.{method}$args and
* {service}.{method}$result.
*/
decodeResponse(Buf.ByteArray.Owned.extract(res.body), RecordingService.echo$result) match {
case Some(result) =>
/*
* If we successfully decoded, we're simply going to append the toString
* of the socket address to the result.
*/
val newResult = result.copy(success=result.success.map(_+addr.toString))
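// Note: this writes a sequence ID of 0 rather than copying the one from the
// original response header.
mux.Response(Buf.ByteArray.Owned(encodeResponse("echo", 0, newResult)))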
case None => res
}
case _ => res
}
}
}
}
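/*
* A Finagle-free model of Recorder's logic, to show the rewrite in isolation.
* It assumes a hypothetical simplified reply type holding the method name from
* the Thrift header and the decoded echo result (None standing in for an
* exception or an undecodable payload). A "service" is just a synchronous
* function and a filter is a function from service to service, in place of
* Finagle's Future-based Service and SimpleFilter.
*/
object RecorderSketch {
  // Hypothetical simplified reply: header method name plus decoded result.
  final case class Reply(method: String, success: Option[String])

  type Svc = String => Reply

  def recorder(addr: java.net.SocketAddress)(next: Svc): Svc = { (req: String) =>
    val res = next(req)
    res.method match {
      // Like Recorder: only "echo" replies are rewritten, by appending addr.toString.
      case "echo" => res.copy(success = res.success.map(_ + addr.toString))
      case _ => res
    }
  }
}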
/*
* Alright. Now we have a filter that can perform the manipulation we're interested in.
* It's time to glue things together.
*
* Finagle clients and servers are constructed from a stack of modules. These stacks are
* then parameterized with an instance of Stack.Params, a type-keyed map. Finagle mints
* a new type (a case class) for each parameter that is used to configure the stack.
* For example, com.twitter.finagle.param.Stats is the key for the StatsReceiver that a
* Finagle client or server uses to report various metrics and gauges.
*
* The parameter we're interested in is com.twitter.finagle.client.Transporter.EndpointAddr.
* This configures the concrete socket address an instance of the stack connects to.
*
* Stacks are first-class objects, which we can manipulate in interesting ways. For example,
* ThriftMux's stack looks like this:
*
* > scala> println(com.twitter.finagle.ThriftMux.client.stack)
* > Node(role = protocolrecorder, description = Record ThriftMux protocol usage)
* > Node(role = traceinitializerfilter, description = Initialize the tracing system)
* > Node(role = clienttracingfilter, description = Report finagle information and client send/recv events)
* > Node(role = prototracing, description = Mux specific clnt traces)
* > Node(role = requeueingfilter, description = Retry automatically on WriteExceptions)
* > Node(role = factorytoservice, description = Apply service factory on each service request)
* > Node(role = prepfactory, description = PrepFactory)
* > Node(role = servicecreationstats, description = Track statistics on service creation failures and service acquisition latency)
* > Node(role = servicetimeout, description = Time out service acquisition after a given period)
* > Node(role = requestdraining, description = RequestDraining)
* > Node(role = binding, description = Bind destination names to endpoints)
* > Node(role = namertracer, description = Trace the details of the Namer lookup)
* > Node(role = loadbalancer, description = Balance requests across multiple endpoints)
* > Node(role = exceptionsource, description = Source exceptions to the service name)
* > Node(role = monitoring, description = Act as last-resort exception handler)
* > Node(role = endpointtracing, description = Record remote address of server)
* > Node(role = dtabstats, description = Report dtab statistics)
* > Node(role = requeststats, description = Report request statistics)
* > Node(role = factorystats, description = Report connection statistics)
* > Node(role = failureaccrual, description = Backoff from hosts that we cannot successfully make requests to)
* > Node(role = requesttimeout, description = Apply a timeout to requests)
* > Node(role = singletonpool, description = Maintain at most one connection)
* > Node(role = failfast, description = Backoff exponentially from hosts to which we cannot establish a connection)
* > Node(role = expiration, description = Expire a service after a certain amount of idle time)
* > Node(role = prepconn, description = PrepConn)
* > Leaf(role = endpoint, description = endpoint)
*
* As you can see, each module in the stack (called a Node here) serves a single, well-defined purpose.
* Requests flow from top to bottom. The role of a module is a unique identifier for the
* purpose that module serves; the description is a short English description.
*
* We're now going to define a module that is parameterized on Transporter.EndpointAddr,
* and then insert it into the right place in the stack.
*
* Stack.Module1 is the type used for 1-ary modules. (Similar to Scala's Function1.)
* Since Stacks are general, they (and thus their modules) are also parameterized
* on the type they produce. In this case, a ServiceFactory[mux.Request, mux.Response].
*/
val recorderModule = new Stack.Module1[
Transporter.EndpointAddr,
ServiceFactory[mux.Request, mux.Response]] {
/*
* We define a new role for our module.
*/
val role = Stack.Role("recordechodest")
val description = "record destination information for RecordingService.echo requests"
/*
* Our module's job is very simple: we're going to extract the SocketAddress
* from our parameter, instantiate a Recorder filter with this address, and
* then apply it to the downstream ServiceFactory.
*/
def make(
_addr: Transporter.EndpointAddr,
next: ServiceFactory[mux.Request, mux.Response]
): ServiceFactory[mux.Request, mux.Response] = {
val Transporter.EndpointAddr(addr) = _addr
new Recorder(addr) andThen next
}
}
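/*
* A Finagle-free model of the two ideas above, under stated assumptions: Params
* is a type-keyed map in which each parameter type supplies a default through an
* implicit Param instance (mirroring how Stack.Params and Stack.Param relate),
* and module1 builds a layer from one parameter plus the next layer (mirroring
* Stack.Module1.make). EndpointAddr here is a hypothetical stand-in for
* Transporter.EndpointAddr with an invented default, and a "layer" can be any
* type, e.g. a String => String function.
*/
object ParamsSketch {
  // Hypothetical typeclass: each parameter type knows its own default.
  trait Param[P] { def default: P }

  final class Params private (values: Map[Param[_], Any]) {
    // Look up a parameter by its type, falling back to that type's default.
    def apply[P](implicit p: Param[P]): P =
      values.get(p).map(_.asInstanceOf[P]).getOrElse(p.default)

    def +[P](value: P)(implicit p: Param[P]): Params = new Params(values + (p -> value))
  }
  object Params { val empty: Params = new Params(Map.empty) }

  // Hypothetical parameter; the "unknown" default is invented for the sketch.
  final case class EndpointAddr(addr: String)
  object EndpointAddr {
    implicit val param: Param[EndpointAddr] = new Param[EndpointAddr] {
      val default = EndpointAddr("unknown")
    }
  }

  // Extract the single parameter from params and use it to wrap the next layer.
  def module1[P: Param, T](make: (P, T) => T): (Params, T) => T =
    (params, next) => make(params[P], next)
}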
/*
* We define a general method to insert a module into a specific place in the Stack:
* immediately before the module with a given role. This should be part of the default
* methods provided by Stack, but it's not (yet), so we'll define it ourselves.
*
* Stacks are structured much like lists. A Stack.Node is akin to a Cons cell; Stack.Leaf to a Nil.
*/
def insert[T](stack: Stack[T], role: Stack.Role, module: Stackable[T]): Stack[T] = {
if (stack.head.role == role) {
/**
* We found the role that we're looking for, so prepend the module.
*/
module +: stack
} else {
stack match {
case Stack.Node(head, mk, next) =>
/*
* We found an interior stack node. We're going to keep it the same, but
* recursively apply insert on the
*/
Stack.Node(head, mk, insert(next, role, module))
case l @ Stack.Leaf(_, _) =>
/* Leaves terminate the insertion. If we reach this point, we failed
* to insert the module. */
l
}
}
}
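/*
* To see insert's behavior without constructing real Finagle modules, here is a
* hypothetical, simplified stack that keeps only roles. Like Stack, it is a
* cons-list of nodes ending in a leaf, and insertBefore follows the same recursion
* as insert above, including returning the stack unchanged when the target role
* is absent. Roles are plain strings here rather than Stack.Role.
*/
object InsertSketch {
  // Hypothetical role-only stack: Node ~ Stack.Node, Leaf ~ Stack.Leaf.
  sealed trait MiniStack { def role: String }
  final case class Node(role: String, next: MiniStack) extends MiniStack
  final case class Leaf(role: String) extends MiniStack

  def insertBefore(stack: MiniStack, target: String, role: String): MiniStack =
    if (stack.role == target) Node(role, stack)
    else stack match {
      case Node(r, next) => Node(r, insertBefore(next, target, role))
      case leaf: Leaf => leaf // target not found: nothing inserted
    }

  // Flatten the stack into its roles, top to bottom.
  def roles(stack: MiniStack): List[String] = stack match {
    case Node(r, next) => r :: roles(next)
    case Leaf(r) => List(r)
  }
}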
/*
* Finally, we're going to create a new ThriftMux client that uses our custom stack.
* ThriftMux is a thin layer on top of Mux that contains logic to translate Thrift
* messages to Bufs and vice versa. ThriftMux also provides convenience methods
* for instantiating client interfaces generated by Scrooge or by the Finagle edition
* of Apache Thrift.
*
* A ThriftMux client is instantiated by providing a Mux client (configured with a
* stack); ThriftMux composes on top of it. Here we simply take the default ThriftMux
* stack and insert our module before the prepConn role (printed as prepconn). If you
* recall, prepconn was near the bottom of our stack:
*
* ...
* > Node(role = requesttimeout, description = Apply a timeout to requests)
* > Node(role = singletonpool, description = Maintain at most one connection)
* > Node(role = failfast, description = Backoff exponentially from hosts to which we cannot establish a connection)
* > Node(role = expiration, description = Expire a service after a certain amount of idle time)
* > Node(role = prepconn, description = PrepConn)
* > Leaf(role = endpoint, description = endpoint)
*
* We need to place our module somewhere below the load balancer, since that is the
* module responsible for spreading load over a number of concrete endpoint clients.
* Modules above the load balancer see a service that may dispatch to any of several
* endpoints, so they don't have access to the true address.
*/
val client = ThriftMux.Client(
Mux.client.copy(insert(ThriftMux.client.stack, StackClient.Role.prepConn, recorderModule)))
}
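/*
* Why the module must sit below the load balancer, as a Finagle-free model with
* hypothetical names: each endpoint is a function that echoes its request, tag
* plays the role of Recorder, and a naive round-robin function stands in for the
* load balancer. Tagging each endpoint before balancing records the address that
* actually served the request; a tag applied above the balancer only has a
* logical name (here the made-up "cluster") to work with. The round-robin counter
* is not thread-safe; it is only meant for this illustration.
*/
object PlacementSketch {
  type Svc = String => String

  // Hypothetical endpoint: echoes the request (addr is unused, as with a real echo server).
  def endpoint(addr: String): Svc = req => req

  // Hypothetical stand-in for Recorder: append an address to every reply.
  def tag(addr: String)(next: Svc): Svc = req => next(req) + "/" + addr

  // Naive stand-in for the load balancer: cycle through endpoints in order.
  def roundRobin(endpoints: Vector[Svc]): Svc = {
    var i = -1
    (req: String) => {
      i = (i + 1) % endpoints.size
      endpoints(i)(req)
    }
  }
}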
/*
* We define a simple test, showing the modified client working:
*
* $ src run
* ok/0.0.0.0:8000
*/
object Main {
def main(args: Array[String]): Unit = {
ThriftMux.serveIface(":8000", new RecordingService.FutureIface {
def echo(arg: String) = Future.value(arg)
})
val client = ThriftMuxRecorded.client.newIface[RecordingService.FutureIface]("0:8000")
println(Await.result(client.echo("ok")))
}
}