Created
October 2, 2016 18:56
-
-
Save Baccata/52dd5a45e99c1c6ae1f8f374afbe29ce to your computer and use it in GitHub Desktop.
monix-grpc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.google.common.util.concurrent.ListenableFuture | |
import com.trueaccord.scalapb.grpc.Grpc | |
import io.grpc.stub.StreamObserver | |
import monix.eval.{Callback, Task} | |
import monix.execution.Ack.{Continue, Stop} | |
import monix.execution.{Ack, Scheduler} | |
import monix.reactive.observers.{BufferedSubscriber, Subscriber} | |
import monix.reactive.subjects.PublishToOneSubject | |
import monix.reactive.{Observable, OverflowStrategy} | |
import scala.concurrent.Future | |
/** Adapters between gRPC-java's callback-based API (`StreamObserver`, Guava
  * `ListenableFuture`) and Monix (`Task`, `Observable`, `Subscriber`).
  *
  * Generated service code refers to these helpers by fully-qualified name,
  * e.g. `grpcmonix.guavaFuture2Task`.
  */
package object grpcmonix {

  /** Converts a Guava future produced by gRPC into a Monix [[Task]].
    *
    * The parameter is by-name and is only evaluated inside `Task.defer`, so an
    * expression such as `ClientCalls.futureUnaryCall(...)` is evaluated when the
    * task is run (once per run) instead of eagerly when the task is built.
    * Passing an already-created future keeps working as before.
    *
    * NOTE(review): nothing here cancels the Guava future if the task is cancelled.
    *
    * @param guavaFuture by-name expression producing the Guava future
    * @return a task completing with the future's result or failure
    */
  def guavaFuture2Task[A](guavaFuture: => ListenableFuture[A]): Task[A] =
    Task.defer(Task.fromFuture(Grpc.guavaFuture2ScalaFuture(guavaFuture)))

  /** Adapts a Monix [[Subscriber]] into a gRPC [[StreamObserver]].
    *
    * Events are pushed through a [[BufferedSubscriber]] with
    * `OverflowStrategy.Unbounded`, so gRPC never waits for the downstream
    * acknowledgement. There is no back-pressure towards gRPC: the buffer can
    * grow without bound when the subscriber is slower than the producer.
    *
    * @param subscriber the Monix subscriber receiving the gRPC events
    */
  def monixSubscriber2GrpcStreamObserver[T](subscriber: Subscriber[T]): StreamObserver[T] =
    new StreamObserver[T] {
      private[this] val rSubscriber =
        BufferedSubscriber[T](subscriber, OverflowStrategy.Unbounded)

      override def onError(ex: Throwable): Unit = rSubscriber.onError(ex)

      override def onCompleted(): Unit = rSubscriber.onComplete()

      override def onNext(value: T): Unit = {
        // The buffered subscriber acknowledges synchronously, so the returned
        // Ack does not have to be chained with any further computation.
        rSubscriber.onNext(value)
      }
    }

  /** Adapts a gRPC [[StreamObserver]] into a Monix [[Subscriber]].
    *
    * Every element is forwarded synchronously and acknowledged with `Continue`.
    * If the observer throws a non-fatal exception from `onNext`, the error is
    * reported to the observer via `onError` and the stream is stopped. Fatal
    * errors are not caught.
    *
    * @param observer the gRPC observer receiving the Monix events
    * @param s        the scheduler exposed by the returned subscriber
    */
  def grpcStreamObserver2monixSubscriber[T](observer: StreamObserver[T],
                                            s: Scheduler): Subscriber[T] =
    new Subscriber[T] {
      override implicit def scheduler: Scheduler = s

      override def onError(ex: Throwable): Unit = observer.onError(ex)

      override def onComplete(): Unit = observer.onCompleted()

      override def onNext(elem: T): Future[Ack] =
        try {
          observer.onNext(elem)
          Continue
        } catch {
          case scala.util.control.NonFatal(ex) =>
            observer.onError(ex)
            Stop
        }
    }

  /** Turns a gRPC "operator" (response observer => request observer) into a
    * Monix operator (downstream subscriber => upstream subscriber), using the
    * downstream subscriber's scheduler for the upstream side.
    */
  private def grpcOperator2MonixOperator[In, Out](
      grpcOperator: StreamObserver[Out] => StreamObserver[In]
  ): Subscriber[Out] => Subscriber[In] = (subscriberOut: Subscriber[Out]) => {
    val streamObserverOut: StreamObserver[Out] =
      monixSubscriber2GrpcStreamObserver(subscriberOut)
    val streamObserverIn: StreamObserver[In] = grpcOperator(streamObserverOut)
    grpcStreamObserver2monixSubscriber(streamObserverIn, subscriberOut.scheduler)
  }

  /** Applies a Monix subscriber-level operator to `obsIn`. */
  def liftByOperator[In, Out](
      operator: Subscriber[Out] => Subscriber[In],
      obsIn: Observable[In]
  ): Observable[Out] =
    obsIn.liftByOperator(operator)

  /** Applies a gRPC-style operator (such as a `ClientCalls.async*StreamingCall`
    * invocation) to `obsIn`, producing the response stream as an Observable.
    */
  def liftFromGrpcOperator[In, Out](
      operator: StreamObserver[Out] => StreamObserver[In],
      obsIn: Observable[In]
  ): Observable[Out] =
    liftByOperator(grpcOperator2MonixOperator(operator), obsIn)

  /** Inverse of lifting: given an Observable transformation and the downstream
    * subscriber, returns a subscriber that feeds its events through
    * `transformer` into `obsOut`.
    *
    * Events received by the returned subscriber are published into a
    * [[PublishToOneSubject]], which is transformed and subscribed to `obsOut`
    * as soon as the subscriber is constructed.
    */
  def unliftByTransformer[In, Out](
      transformer: Observable[In] => Observable[Out],
      obsOut: Subscriber[Out]
  ): Subscriber[In] =
    new Subscriber[In] {
      // Declared (and typed) before any constructor code runs, so it is never
      // observed as an uninitialized forward reference.
      implicit val scheduler: Scheduler = obsOut.scheduler

      private[this] val input = PublishToOneSubject[In]()

      input.transform(transformer).subscribe(obsOut)

      def onNext(elem: In): Future[Ack] = input.onNext(elem)

      def onError(ex: Throwable): Unit = input.onError(ex)

      def onComplete(): Unit = input.onComplete()
    }

  /** Adapts a gRPC [[StreamObserver]] into a Monix [[Callback]] for a single
    * result. On success, the value is sent with `onNext` followed by
    * `onCompleted`. On failure, the error is forwarded with `onError`.
    */
  def grpcStreamObserver2Callback[T](observer: StreamObserver[T]): Callback[T] =
    new Callback[T] {
      override def onError(ex: Throwable): Unit =
        observer.onError(ex)

      override def onSuccess(value: T): Unit = {
        observer.onNext(value)
        observer.onCompleted()
      }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor} | |
import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo | |
import com.trueaccord.scalapb.compiler._ | |
import scala.collection.JavaConverters._ | |
/** ScalaPB code generator emitting a Monix-based gRPC binding for one service.
  *
  * For `service` it prints:
  *   - a trait whose methods use `Task` / `Observable`;
  *   - the gRPC `MethodDescriptor`s;
  *   - a client stub implementing the trait;
  *   - a `bindService` helper for servers;
  *   - a `stub` factory.
  *
  * The emitted code calls the `grpcmonix.*` adapter functions by
  * fully-qualified name.
  *
  * @param service protobuf service to generate code for
  * @param params  ScalaPB generator parameters, required by `DescriptorPimps`
  */
final class MonixGrpcPrinter(service: ServiceDescriptor,
                             override val params: GeneratorParams)
    extends DescriptorPimps {

  /** Renders the Scala type `Observable[typeParam]`. */
  private[this] def observable(typeParam: String): String =
    s"Observable[$typeParam]"

  /** Renders the Scala type `Task[typeParam]`. */
  private[this] def task(typeParam: String): String =
    s"Task[$typeParam]"

  /** Renders the abstract signature of a service method.
    *
    * A streaming side is exposed as an `Observable`. A non-streaming request is
    * a plain `request` argument, and a non-streaming response is a `Task`.
    */
  private[this] def serviceMethodSignature(method: MethodDescriptor): String =
    s"def ${method.name}" + (method.streamType match {
      case StreamType.Unary =>
        s"(request: ${method.scalaIn}): ${task(method.scalaOut)}"
      case StreamType.ClientStreaming =>
        s"(input: ${observable(method.scalaIn)}): ${task(method.scalaOut)}"
      case StreamType.ServerStreaming =>
        s"(request: ${method.scalaIn}): ${observable(method.scalaOut)}"
      case StreamType.Bidirectional =>
        s"(input: ${observable(method.scalaIn)}): ${observable(method.scalaOut)}"
    })

  /** Prints the service trait with one abstract method per RPC. Each signature
    * string gets a trailing "\n" appended.
    */
  private[this] def serviceTrait: PrinterEndo = {
    val signatures: PrinterEndo =
      _.seq(service.methods.map(m => serviceMethodSignature(m) + "\n"))

    p => p.add(s"trait ${service.name} {").withIndent(signatures).add("}")
  }

  // Fully-qualified names referenced from the generated code.
  private[this] val channel = "_root_.io.grpc.Channel"
  private[this] val callOptions = "_root_.io.grpc.CallOptions"
  private[this] val streamObserver = "_root_.io.grpc.stub.StreamObserver"
  private[this] val abstractStub = "_root_.io.grpc.stub.AbstractStub"
  private[this] val serverCalls = "_root_.io.grpc.stub.ServerCalls"
  private[this] val clientCalls = "_root_.io.grpc.stub.ClientCalls"
  private[this] val guavaFuture2Task = "grpcmonix.guavaFuture2Task"
  private[this] val liftFromGrpcOperator = "grpcmonix.liftFromGrpcOperator"
  private[this] val subscriber2Observer = "grpcmonix.monixSubscriber2GrpcStreamObserver"
  private[this] val observer2Subscriber = "grpcmonix.grpcStreamObserver2monixSubscriber"
  private[this] val observer2Callback = "grpcmonix.grpcStreamObserver2Callback"
  private[this] val unliftByTransformer = "grpcmonix.unliftByTransformer"

  /** Prints the client-stub override for one RPC, followed by a blank line.
    * `channel` and `options` in the emitted code are the stub class's
    * constructor parameters (see `stubImplementation`).
    */
  private[this] def clientMethodImpl(m: MethodDescriptor): PrinterEndo =
    PrinterEndo { p =>
      m.streamType match {
        case StreamType.Unary =>
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = {
                |  $guavaFuture2Task($clientCalls.futureUnaryCall(channel.newCall(${m.descriptorName}, options), request))
                |}"""
          )
        case StreamType.ServerStreaming =>
          // Keep a handle on the ClientCall so that cancelling the subscription
          // cancels the RPC instead of leaving it running in the background.
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = Observable.create[${m.scalaOut}](OverflowStrategy.Unbounded) {
                |  subscriber: Subscriber[${m.scalaOut}] =>
                |    val call = channel.newCall(${m.descriptorName}, options)
                |    val responseObserver = $subscriber2Observer(subscriber)
                |    $clientCalls.asyncServerStreamingCall(call, request, responseObserver)
                |    Cancelable(() => call.cancel("Cancelled by subscriber", null))
                |}"""
          )
        case StreamType.Bidirectional =>
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = {
                |  $liftFromGrpcOperator({responseObserver : $streamObserver[${m.scalaOut}] =>
                |    $clientCalls.asyncBidiStreamingCall(channel.newCall(${m.descriptorName}, options), responseObserver)
                |  }, input)
                |}"""
          )
        case StreamType.ClientStreaming =>
          // The single response is surfaced as the first element of the lifted stream.
          p.addM(
            s"""|override ${serviceMethodSignature(m)} = {
                |  $liftFromGrpcOperator({responseObserver : $streamObserver[${m.scalaOut}] =>
                |    $clientCalls.asyncClientStreamingCall(channel.newCall(${m.descriptorName}, options), responseObserver)
                |  }, input).headL
                |}"""
          )
      }
    } andThen PrinterEndo { _.newline }

  /** Prints the client stub class.
    *
    * The class extends gRPC's `AbstractStub` and the service trait `baseClass`.
    * Its body contains the given method implementations and the `build`
    * override that `AbstractStub` uses to re-create the stub.
    */
  private def stubImplementation(
      className: String,
      baseClass: String,
      methods: Seq[PrinterEndo]
  ): PrinterEndo = { p =>
    val build =
      s"  override def build(channel: $channel, options: $callOptions): $className = new $className(channel, options)"
    p.add(
        s"class $className(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$className](channel, options) with $baseClass {"
      )
      .withIndent(methods: _*)
      .add(build)
      .add("}")
  }

  /** Printer for the client stub class covering every RPC of the service. */
  private[this] val stub: PrinterEndo =
    stubImplementation(service.stub, service.name, service.methods.map(clientMethodImpl))

  /** Prints the `val` holding the gRPC `MethodDescriptor` of one RPC, using
    * ScalaPB's `Marshaller` for the request and response messages.
    */
  private[this] def methodDescriptor(method: MethodDescriptor) = PrinterEndo { p =>
    def marshaller(typeName: String) =
      s"new com.trueaccord.scalapb.grpc.Marshaller($typeName)"

    val methodType = method.streamType match {
      case StreamType.Unary           => "UNARY"
      case StreamType.ClientStreaming => "CLIENT_STREAMING"
      case StreamType.ServerStreaming => "SERVER_STREAMING"
      case StreamType.Bidirectional   => "BIDI_STREAMING"
    }
    val grpcMethodDescriptor = "_root_.io.grpc.MethodDescriptor"
    p.addM(
      s"""val ${method.descriptorName}: $grpcMethodDescriptor[${method.scalaIn}, ${method.scalaOut}] =
         |  $grpcMethodDescriptor.create(
         |    $grpcMethodDescriptor.MethodType.$methodType,
         |    $grpcMethodDescriptor.generateFullMethodName("${service.getFullName}", "${method.getName}"),
         |    ${marshaller(method.scalaIn)},
         |    ${marshaller(method.scalaOut)})
         |""")
  }

  /** Prints one `.addMethod(descriptor, handler)` call for the server builder.
    *
    * The handler adapts the matching `ServerCalls` callback interface to the
    * Monix service implementation. It uses `serviceImpl` and `scheduler`, which
    * are the parameters of the generated `bindService`.
    */
  private[this] def addMethodImplementation(method: MethodDescriptor): PrinterEndo =
    PrinterEndo {
      _.add(".addMethod(")
        .add(s" ${method.descriptorName},")
        .withIndent(PrinterEndo { p =>
          val call = method.streamType match {
            case StreamType.Unary           => s"$serverCalls.asyncUnaryCall"
            case StreamType.ClientStreaming => s"$serverCalls.asyncClientStreamingCall"
            case StreamType.ServerStreaming => s"$serverCalls.asyncServerStreamingCall"
            case StreamType.Bidirectional   => s"$serverCalls.asyncBidiStreamingCall"
          }
          val scheduler = "scheduler"
          val serviceImpl = "serviceImpl"
          method.streamType match {
            case StreamType.Unary =>
              val serverMethod =
                s"$serverCalls.UnaryMethod[${method.scalaIn}, ${method.scalaOut}]"
              p.addM(
                s"""$call(new $serverMethod {
                   |  override def invoke(request: ${method.scalaIn}, observer: $streamObserver[${method.scalaOut}]): Unit =
                   |    $serviceImpl.${method.name}(request).runAsync($observer2Callback(observer))($scheduler)
                   |}))""")
            case StreamType.ServerStreaming =>
              val serverMethod =
                s"$serverCalls.ServerStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
              p.addM(
                s"""$call(new $serverMethod {
                   |  override def invoke(request: ${method.scalaIn}, observer: $streamObserver[${method.scalaOut}]): Unit =
                   |    $serviceImpl.${method.name}(request).subscribe($observer2Subscriber(observer, $scheduler))
                   |}))""")
            case StreamType.Bidirectional =>
              val serverMethod =
                s"$serverCalls.BidiStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
              p.addM(
                s"""$call(new $serverMethod {
                   |  override def invoke(observer: $streamObserver[${method.scalaOut}]): $streamObserver[${method.scalaIn}] = {
                   |    val subscriberOut : Subscriber[${method.scalaOut}] = $observer2Subscriber(observer, $scheduler)
                   |    val subscriberIn : Subscriber[${method.scalaIn}] = $unliftByTransformer(
                   |      observableIn => $serviceImpl.${method.name}(observableIn),
                   |      subscriberOut
                   |    )
                   |    $subscriber2Observer(subscriberIn)
                   |  }
                   |}))""")
            case StreamType.ClientStreaming =>
              // The Task result is wrapped as a one-element Observable so the same
              // unlift machinery as the bidirectional case can be reused.
              val serverMethod =
                s"$serverCalls.ClientStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
              p.addM(
                s"""$call(new $serverMethod {
                   |  override def invoke(observer: $streamObserver[${method.scalaOut}]): $streamObserver[${method.scalaIn}] = {
                   |    val subscriberOut : Subscriber[${method.scalaOut}] = $observer2Subscriber(observer, $scheduler)
                   |    val subscriberIn : Subscriber[${method.scalaIn}] = $unliftByTransformer(
                   |      observableIn => Observable.fromTask($serviceImpl.${method.name}(observableIn)),
                   |      subscriberOut
                   |    )
                   |    $subscriber2Observer(subscriberIn)
                   |  }
                   |}))""")
          }
        })
    }

  /** Printer for `bindService(serviceImpl, scheduler)`, which builds a
    * `ServerServiceDefinition` registering every RPC of the service.
    */
  private[this] val bindService = {
    val scheduler = "scheduler"
    val methods = service.methods.map(addMethodImplementation)
    val serverServiceDef = "_root_.io.grpc.ServerServiceDefinition"
    PrinterEndo(
      _.add(
        s"""def bindService(serviceImpl: ${service.name}, $scheduler: monix.execution.Scheduler): $serverServiceDef =""")
        .withIndent(
          _.add(s"""$serverServiceDef.builder("${service.getFullName}")"""),
          _.call(methods: _*),
          _.add(".build()")))
  }

  /** Prints the complete generated source for `service`.
    *
    * The output contains the package clause, the Monix imports used by the
    * generated code, the service trait, and its companion object. The companion
    * holds the method descriptors, the stub, `bindService`, `stub(channel)` and
    * `descriptor`.
    *
    * @param printer printer to append to
    * @return the printer with the generated code appended
    */
  def printService(printer: FunctionalPrinter): FunctionalPrinter = {
    printer
      .add(
        "package " + service.getFile.scalaPackageName,
        "",
        "import monix.execution.Cancelable",
        "import monix.eval.Task",
        "import monix.reactive.observers.Subscriber",
        "import monix.reactive.{Observable, OverflowStrategy}",
        ""
      )
      .call(serviceTrait)
      .newline
      // NOTE(review): `grpcmonix.MonixGrpcServiceCompanion` is not one of the
      // helpers defined in the `grpcmonix` package object; confirm it is
      // defined elsewhere, otherwise the generated code will not compile.
      .add(
        s"object ${service.getName} extends grpcmonix.MonixGrpcServiceCompanion[${service.getName}] {"
      )
      .newline
      .withIndent(
        _.call(service.methods.map(methodDescriptor): _*),
        _.newline,
        stub,
        _.newline,
        bindService,
        _.newline,
        _.add(
          s"def stub(channel: $channel): ${service.stub} = new ${service.stub}(channel)"),
        _.newline,
        _.add(
          s"def descriptor: _root_.com.google.protobuf.Descriptors.ServiceDescriptor = ${service.getFile.fileDescriptorObjectFullName}.descriptor.getServices().get(${service.getIndex})")
      )
      .add("}")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hey, I'm interested in your gist. Is it working, and can it handle gRPC flow control through Monix backpressure? Also, how do you register it with ScalaPB so that the compiler uses your generator? Thanks a lot!