Created October 24, 2020 17:04
Save w4rum/4e20ec18b9065b1b6780e2f92ac4b6f0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
syntax = "proto3";

package example_service;

// Example service exposing one RPC for each of the four gRPC
// request/response streaming combinations.
service ExampleService {
  // Single request in, single response out.
  rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);

  // Single request in, stream of responses out.
  rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);

  // Stream of requests in, single response out.
  rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);

  // Stream of requests in, stream of responses out.
  rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
}
// Message sent by the client to every ExampleService RPC.
message ExampleRequest {
  // Free-form text payload.
  string example_string = 1;
}
// Message returned by every ExampleService RPC.
message ExampleResponse {
  // Free-form text payload.
  string example_string = 1;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Original gist by nat-n: | |
https://gist.github.com/nat-n/e90097ebfb861cbb25e20b68bec0e39c | |
""" | |
import functools
import inspect
import typing
from collections.abc import AsyncIterable, AsyncIterator

import grpclib
import typing_inspect
from grpclib.const import Cardinality
from grpclib.server import Stream
def rpc_unary_unary(*, name=None):
    """Decorator for a handler that takes one request and returns one response.

    ``name`` overrides the RPC name used in the route; when omitted the
    decorated function's ``__name__`` is used.
    """
    return _rpc_method(cardinality=Cardinality.UNARY_UNARY, name=name)
def rpc_unary_stream(*, name=None):
    """Decorator for a handler that takes one request and streams responses.

    The handler's return hint must be ``AsyncIterator[ResponseMessage]``.
    ``name`` overrides the RPC name used in the route.
    """
    return _rpc_method(cardinality=Cardinality.UNARY_STREAM, name=name)
def rpc_stream_unary(*, name=None):
    """Decorator for a handler that consumes a request stream, returns one response.

    The handler's ``request`` hint must be ``AsyncIterator[RequestMessage]``.
    ``name`` overrides the RPC name used in the route.
    """
    return _rpc_method(cardinality=Cardinality.STREAM_UNARY, name=name)
def rpc_stream_stream(*, name=None):
    """Decorator for a handler with streaming requests and streaming responses.

    Both the ``request`` hint and the return hint must be
    ``AsyncIterator[...]`` of the respective message types. ``name``
    overrides the RPC name used in the route.
    """
    return _rpc_method(cardinality=Cardinality.STREAM_STREAM, name=name)
def _stream_item_type(hint, side):
    """Return ``T`` from a streaming type hint such as ``AsyncIterator[T]``.

    Any subscripted hint whose origin is a subclass of
    ``collections.abc.AsyncIterable`` (e.g. ``AsyncIterator[T]`` or
    ``AsyncIterable[T]``) is accepted. ``side`` ("request" or "response")
    is only used in the error message.

    Raises:
        TypeError: if *hint* is not a subscripted async-iterable hint.
    """
    origin = typing_inspect.get_origin(hint)
    args = typing_inspect.get_args(hint)
    if not (isinstance(origin, type)
            and issubclass(origin, AsyncIterable)
            and args):
        raise TypeError(f"streaming {side} type hint is not AsyncIterator")
    return args[0]


def _rpc_method(*, name=None, cardinality=Cardinality.UNARY_UNARY):
    """Build a decorator that adapts a typed handler into a grpclib handler.

    The decorated function must have the shape ``(self, request) -> response``
    with type hints on ``request`` and on the return value. For
    client-streaming cardinalities ``request`` is hinted as
    ``AsyncIterator[RequestMessage]``. For server-streaming cardinalities the
    return is hinted as ``AsyncIterator[ResponseMessage]``. The bare message
    types are unwrapped from those hints.

    The returned wrapper takes ``(self, stream)``. It carries an
    ``__rpc_method__`` dict with ``request_type``, ``response_type``,
    ``cardinality`` and ``name``.

    Args:
        name: RPC name used in the route; defaults to the decorated
            function's ``__name__``.
        cardinality: ``grpclib.const.Cardinality`` saying which sides stream.

    Raises:
        TypeError: at decoration time, if a required type hint is missing or
            a streaming side is not hinted as an async iterator.
    """

    def inner_decorator(func):
        hints = typing.get_type_hints(func)
        request_type = hints.get("request")
        response_type = hints.get("return")
        # Explicit raises (not ``assert``) so the checks survive ``python -O``.
        if not request_type:
            raise TypeError("request type must be annotated")
        if not response_type:
            raise TypeError("response type must be annotated")

        # grpclib needs the bare message types, not the stream wrappers.
        if cardinality.server_streaming:
            response_type = _stream_item_type(response_type, "response")
        if cardinality.client_streaming:
            request_type = _stream_item_type(request_type, "request")

        @functools.wraps(func)
        async def wrapper(self, stream: Stream[request_type, response_type]):
            # Receive the request: one message, or a lazy async iterator over
            # the client's message stream.
            if cardinality.client_streaming:
                async def request_iterator() \
                        -> typing.AsyncIterator[request_type]:
                    async for request_message in stream:
                        yield request_message

                request = request_iterator()
            else:
                request = await stream.recv_message()

            # Run the handler and send its response(s) back to the client.
            if not cardinality.server_streaming:
                response = await func(self, request)
                await stream.send_message(response)
                return

            responses = func(self, request)
            # An ``async def`` containing ``yield`` returns an async generator.
            # One without any ``yield`` returns a coroutine instead. Await it
            # so its body actually runs. It may return an async iterable to
            # stream from, or None meaning "no messages".
            if inspect.isawaitable(responses):
                responses = await responses
            if responses is not None:
                async for response_message in responses:
                    await stream.send_message(response_message)

        wrapper.__rpc_method__ = {
            "request_type": request_type,
            "response_type": response_type,
            "cardinality": cardinality,
            "name": name or func.__name__,
        }
        return wrapper

    return inner_decorator
def is_rpc_method(value):
    """Tell whether *value* is a callable tagged with an ``__rpc_method__`` dict.

    Returns a bool. Non-callables are rejected before their attributes are
    inspected.
    """
    if not callable(value):
        return False
    metadata = getattr(value, "__rpc_method__", None)
    return isinstance(metadata, dict)
class ServiceStub:
    """Base class that exposes ``@rpc_*``-decorated methods to grpclib.

    Subclasses must define ``service_name``, the fully qualified proto service
    name. Each decorated method is served at ``/<service_name>/<rpc name>``.
    """

    def __mapping__(self):
        """Build the route -> ``grpclib.const.Handler`` table for this service.

        Raises:
            ValueError: if two decorated methods resolve to the same route.
                A plain dict build would otherwise silently keep only the
                last one.
        """
        mapping = {}
        for _, method in inspect.getmembers(self, is_rpc_method):
            meta = method.__rpc_method__
            route = f"/{self.service_name}/{meta['name']}"
            if route in mapping:
                raise ValueError(f"duplicate RPC route {route!r}")
            mapping[route] = grpclib.const.Handler(
                method,
                meta["cardinality"],
                meta["request_type"],
                meta["response_type"],
            )
        return mapping
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from typing import AsyncIterator | |
from grpclib.server import Server | |
from .magic_glue import rpc_unary_unary, rpc_unary_stream, ServiceStub, \ | |
rpc_stream_stream, rpc_stream_unary | |
from ..protobuf.example_service import ExampleRequest, ExampleResponse | |
class ExampleService(ServiceStub):
    """Implementation of ``example_service.ExampleService``.

    Every RPC echoes ``example_string`` from the request(s) back in the
    response(s).
    """

    service_name = "example_service.ExampleService"

    @rpc_unary_unary(name="ExampleUnaryUnary")
    async def example_unary_unary(
        self, request: ExampleRequest
    ) -> ExampleResponse:
        """Echo the request once."""
        return ExampleResponse(example_string=request.example_string)

    @rpc_unary_stream(name="ExampleUnaryStream")
    async def example_unary_stream(
        self, request: ExampleRequest
    ) -> AsyncIterator[ExampleResponse]:
        """Echo the request back three times as a response stream."""
        for _ in range(3):
            yield ExampleResponse(example_string=request.example_string)

    @rpc_stream_unary(name="ExampleStreamUnary")
    async def example_stream_unary(
        self, request: AsyncIterator[ExampleRequest]
    ) -> ExampleResponse:
        """Echo the first streamed request; later messages are not read.

        If the client ends the stream without sending anything, return an
        empty ``ExampleResponse``. Falling off the loop would return None,
        and the unary wrapper would then try to send None as a message.
        """
        async for request_message in request:
            # Just answer the first message.
            return ExampleResponse(
                example_string=request_message.example_string)
        return ExampleResponse()

    @rpc_stream_stream(name="ExampleStreamStream")
    async def example_stream_stream(
        self, request: AsyncIterator[ExampleRequest]
    ) -> AsyncIterator[ExampleResponse]:
        """Echo every streamed request as it arrives."""
        async for request_message in request:
            yield ExampleResponse(
                example_string=request_message.example_string)
async def run_rpc_server(host="localhost", port=50051):
    """Serve ``ExampleService`` on *host*:*port* until the server is closed.

    Args:
        host: Interface to bind (default ``"localhost"``).
        port: TCP port to listen on (default ``50051``).
    """
    server = Server([ExampleService()])
    await server.start(host, port)
    try:
        await server.wait_closed()
    finally:
        # If this coroutine is cancelled while waiting, stop accepting
        # connections instead of leaving the listening socket open.
        server.close()
if __name__ == '__main__':
    # Executed as a script (e.g. via ``python -m``): run the server on the
    # event loop until it is closed.
    server_main = run_rpc_server()
    asyncio.run(server_main)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from grpclib.client import Channel | |
from ..protobuf.example_service import ExampleServiceStub, ExampleRequest | |
async def main(host="localhost", port=50051):
    """Call each ExampleService RPC and print what the server sends back.

    Args:
        host: Address of the ExampleService server (default ``"localhost"``).
        port: Port of the ExampleService server (default ``50051``).
    """
    channel = Channel(host, port)
    try:
        example_service = ExampleServiceStub(channel)
        print("-- Start")

        print("Unary Unary:")
        print(await example_service.example_unary_unary(
            example_string="TEST UNARY UNARY"))

        print("Unary Stream:")
        async for response_message in example_service.example_unary_stream(
                example_string="TEST UNARY STREAM"):
            print(response_message)

        async def stream_request():
            # Each call creates a fresh generator, so both streaming RPCs
            # below get their own ten messages.
            for i in range(10):
                yield ExampleRequest(example_string=f"TEST STREAM UNARY {i}")

        print("Stream Unary:")
        print(await example_service.example_stream_unary(stream_request()))

        print("Stream Stream:")
        async for response_message in example_service.example_stream_stream(
                stream_request()):
            print(response_message)
    finally:
        # Release the connection even if one of the RPCs raises.
        channel.close()
    print("-- Done")
if __name__ == "__main__":
    # Guard the entry point so that merely importing this module does not
    # open a connection and run the demo client.
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.