Last active
April 15, 2020 10:35
-
-
Save RussellLuo/9ee9585e3c2b0dbd0298574c241e1bcf to your computer and use it in GitHub Desktop.
gRPC client interface for Python: generation script and mocking class.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding=utf-8 -*- | |
"""Generate a pythonic interface based on the code generated by `grpcio-tools`. | |
Example: | |
$ python grpc_pi.py --proto-package-name='xx' --pb2-module-name='python.path.xx_pb2' | |
""" | |
import argparse | |
import itertools | |
import re | |
import sys | |
from collections import OrderedDict | |
from importlib import import_module | |
import grpc | |
class Generator(object):
    """Write a pythonic client-interface module for a `grpcio-tools` pb2 module.

    The generated source is written piece by piece to `writer`: a module
    header, one `enum.Enum` class per matching enum descriptor, one alias
    per matching message descriptor, and an `<Package>Interface` class with
    a constructor, a `stub` property, a core dispatch method and one method
    per RPC found on the stub.
    """

    # Destination of the generated source; only its `write()` method is used.
    writer = sys.stdout

    def __init__(self, proto_package_name, pb2_module_name,
                 core_method_name, unfold_method_args, rpc_method_args_size):
        """Store the options and import the pb2 module.

        :param proto_package_name: package name of the proto file; used as
            the name prefix for selecting descriptors and (camelized) as the
            prefix of the stub/interface class names.
        :param pb2_module_name: dotted Python path of the generated pb2
            module; it is imported here.
        :param core_method_name: name of the generated method that performs
            the actual RPC call.
        :param unfold_method_args: if true, each generated RPC method takes
            the request fields as arguments instead of a request object.
        :param rpc_method_args_size: number of parameters per line in the
            unfolded method signatures; a falsy value keeps them on one line.
        """
        self.proto_package_name = proto_package_name
        self.pb2_module_name = pb2_module_name
        self.core_method_name = core_method_name
        self.unfold_method_args = unfold_method_args
        self.rpc_method_args_size = rpc_method_args_size

        self.stub_class_name = self.camelize(self.proto_package_name) + 'Stub'

        # Split "a.b.xx_pb2" into ("a.b", "xx_pb2") so the generated module
        # can use `from a.b import xx_pb2`.
        if '.' in self.pb2_module_name:
            self.pb2_path, self.pb2_name = self.pb2_module_name.rsplit('.', 1)
        else:
            self.pb2_path, self.pb2_name = '', self.pb2_module_name

        self.pb2_module = import_module(self.pb2_module_name)
        # NOTE(review): `_sym_db` and the pool's `_enum_descriptors` /
        # `_descriptors` used below are private protobuf attributes; confirm
        # they exist for the protobuf implementation in use.
        self.sym_db_pool = self.pb2_module._sym_db.pool

    @staticmethod
    def slice_every(iterable, n, padding=False, padding_item=None):
        """Yield successive lists of at most `n` items from `iterable`.

        When `padding` is true, a short final list is extended with
        `padding_item` up to length `n`.
        """
        iterable = iter(iterable)
        while True:
            piece = list(itertools.islice(iterable, n))
            if not piece:
                return
            padding_len = n - len(piece)
            if padding_len and padding:
                piece.extend([padding_item] * padding_len)
            yield piece

    @staticmethod
    def camelize(string, uppercase_first_letter=True):
        """Convert strings to CamelCase.

        Borrowed from https://github.com/jpvanhal/inflection/blob/master/inflection.py
        """
        if uppercase_first_letter:
            return re.sub(r"(?:^|_)(.)", lambda m: m.group(1).upper(), string)
        if not string:
            # Nothing to lowercase; indexing string[0] would raise IndexError.
            return string
        return string[0].lower() + Generator.camelize(string)[1:]

    @staticmethod
    def underscore(word):
        """Make an underscored, lowercase form from the expression
        in the string.

        Borrowed from https://github.com/jpvanhal/inflection/blob/master/inflection.py
        """
        word = re.sub(r"([A-Z]+)([A-Z][a-z])", r'\1_\2', word)
        word = re.sub(r"([a-z\d])([A-Z])", r'\1_\2', word)
        word = word.replace("-", "_")
        return word.lower()

    def has_enum_types(self):
        """Return True if any enum descriptor name starts with the package name."""
        return any(name.startswith(self.proto_package_name)
                   for name in self.sym_db_pool._enum_descriptors)

    def write_module_header(self):
        """Write the coding line and the imports of the generated module."""
        if self.pb2_path:
            import_pb2 = 'from {pb2_path} import {pb2_name}'.format(
                pb2_path=self.pb2_path,
                pb2_name=self.pb2_name
            )
        else:
            import_pb2 = 'import {pb2_name}'.format(pb2_name=self.pb2_name)

        self.writer.write(
            '# -*- coding: utf-8 -*-\n'
            '{import_enum}'
            '\nimport grpc'
            '\n\n{import_pb2}'.format(
                # `enum` is only imported when an enum class will be emitted.
                import_enum='\nimport enum' if self.has_enum_types() else '',
                import_pb2=import_pb2
            )
        )

    def write_enum_types(self):
        """Write one `enum.Enum` class per enum descriptor of the package."""
        for name, enum in self.sym_db_pool._enum_descriptors.items():
            if name.startswith(self.proto_package_name):
                values = '\n'.join(
                    '    {name} = {number}'.format(name=value.name,
                                                   number=value.number)
                    for value in enum.values
                )
                self.writer.write(
                    '\n\n\nclass {enum_name}(enum.Enum):\n'
                    '{values}'.format(enum_name=enum.name, values=values)
                )

    def write_message_types(self):
        """Write a module-level alias for each message class of the package."""
        self.writer.write('\n\n')
        for name, message in self.sym_db_pool._descriptors.items():
            if name.startswith(self.proto_package_name):
                self.writer.write(
                    '\n{name} = {pb2_name}.{name}'.format(
                        name=message.name,
                        pb2_name=self.pb2_name
                    )
                )

    def write_class_header(self):
        """Write the `class <Package>Interface(object):` line."""
        class_prefix = self.camelize(self.proto_package_name)
        self.writer.write(
            '\n\n\nclass {}Interface(object):\n'.format(class_prefix)
        )

    def write_class_constructor(self):
        """Write the generated `__init__(self, target, timeout=10)`."""
        self.writer.write(
            '\n    def __init__(self, target, timeout=10):'
            '\n        self.target = target'
            '\n        self.timeout = timeout\n'
        )

    def write_stub_property(self):
        """Write the generated `stub` property, which builds a stub on access."""
        self.writer.write(
            '\n    @property\n'
            '    def stub(self):\n'
            '        channel = grpc.insecure_channel(self.target)\n'
            '        return {pb2_name}.{stub_class_name}(channel)\n'.format(
                pb2_name=self.pb2_name,
                stub_class_name=self.stub_class_name
            )
        )

    def write_core_method(self):
        """Write the generated core method that looks up and calls an RPC."""
        self.writer.write(
            '\n    def {core_method_name}(self, rpc_name, req):\n'
            '        rpc = getattr(self.stub, rpc_name)\n'
            '        resp = rpc(req, self.timeout)\n'
            '        return resp\n'.format(
                core_method_name=self.core_method_name
            )
        )

    def write_folded_rpc_method(self, method_name, req_name):
        """Write an RPC method that takes the request object as its argument."""
        self.writer.write(
            "\n    def {underscored_method_name}(self, {req_name}):\n"
            "        resp = self.{core_method_name}('{method_name}', {req_name})\n"
            "        return resp\n".format(
                underscored_method_name=self.underscore(method_name),
                req_name=self.underscore(req_name),
                core_method_name=self.core_method_name,
                method_name=method_name
            )
        )

    def write_unfolded_rpc_method(self, method_name, req_name, req_param_names):
        """Write an RPC method that takes the request fields as arguments.

        `req_param_names` must be a list; the parameters are wrapped every
        `rpc_method_args_size` items (all on one line when that is falsy).
        """
        indented_header = '    def {}('.format(self.underscore(method_name))

        full_params = ['self'] + req_param_names
        args_size = self.rpc_method_args_size or len(full_params)
        # Continuation lines are aligned under the first parameter.
        separator = ',\n' + len(indented_header) * ' '
        indented_params = separator.join(
            ', '.join(params)
            for params in self.slice_every(full_params, args_size)
        )

        indented_kwargs = ',\n'.join(
            '            {0}={0}'.format(param_name)
            for param_name in req_param_names
        )
        indented_body = (
            "        req = {req_name}(\n"
            "{indented_kwargs}\n"
            "        )\n"
            "        resp = self.{core_method_name}('{method_name}', req)\n"
            "        return resp\n".format(
                req_name=req_name,
                indented_kwargs=indented_kwargs,
                core_method_name=self.core_method_name,
                method_name=method_name
            )
        )

        self.writer.write(
            '\n{indented_header}'
            '{indented_params}):\n'
            '{indented_body}'.format(
                indented_header=indented_header,
                indented_params=indented_params,
                indented_body=indented_body
            )
        )

    def write_rpc_methods(self):
        """Write one generated method per attribute of a stub instance.

        A stub is instantiated on a throwaway channel purely to introspect
        its attributes; methods are emitted in sorted name order.
        """
        stub_class = getattr(self.pb2_module, self.stub_class_name)
        channel = grpc.insecure_channel('localhost')
        stub = stub_class(channel)

        stub_method_names = sorted(
            attr for attr in dir(stub) if not attr.startswith('__')
        )

        for stub_method_name in stub_method_names:
            stub_method = getattr(stub, stub_method_name)
            # NOTE(review): `im_class` exists only on Python 2 method
            # objects, so this lookup ties the script to Python 2.
            req_class = stub_method._request_serializer.im_class
            req_name = req_class.__name__
            req_param_names = [
                self.underscore(field.name)
                for field in req_class.DESCRIPTOR.fields
            ]
            if self.unfold_method_args:
                self.write_unfolded_rpc_method(stub_method_name, req_name,
                                               req_param_names)
            else:
                self.write_folded_rpc_method(stub_method_name, req_name)

    def generate(self):
        """Write the complete generated module to `writer`."""
        self.write_module_header()
        self.write_enum_types()
        self.write_message_types()
        self.write_class_header()
        self.write_class_constructor()
        self.write_stub_property()
        self.write_core_method()
        self.write_rpc_methods()
def main(argv=None):
    """Parse command-line options and write the generated interface module.

    :param argv: optional list of argument strings; when None, argparse
        reads the arguments from `sys.argv`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto-package-name', required=True,
                        help='The package name of the proto file.')
    parser.add_argument('--pb2-module-name', required=True,
                        help='The name of the generated `xx_pb2.py` '
                             'module with the full Python path.')
    parser.add_argument('--core-method-name', default='call_rpc',
                        help='The name of the core method that will be '
                             'used to call the actual rpc methods.')
    parser.add_argument('--unfold-method-args', action='store_true',
                        help='Whether or not to unfold the request '
                             'attributes as the arguments of each rpc method.')
    parser.add_argument('--rpc-method-args-size', type=int, default=0,
                        help='The number of arguments per line in the '
                             'definition of each rpc method.')
    args = parser.parse_args(argv)

    generator = Generator(args.proto_package_name,
                          args.pb2_module_name,
                          args.core_method_name,
                          args.unfold_method_args,
                          args.rpc_method_args_size)
    generator.generate()


# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding=utf-8 -*- | |
"""A mocking gRPC client interface for tests. | |
Usage: | |
1. Suppose you want to test the RPC function `create_user` of gRPC | |
client interface `UserInterface`, then the test may look like: | |
``` | |
import unittest | |
from user_module import UserInterface | |
class TestUserInterface(unittest.TestCase): | |
user_inter = UserInterface(...) | |
def test_create_user(self): | |
result = self.user_inter.create_user(name='russellluo') | |
self.assertEqual(result, 0) | |
``` | |
The above test will pass if the real gRPC server is running, but | |
running a gRPC server for tests is cumbersome. | |
2. Use `Mocker` to simplify the test environment
``` | |
import unittest | |
from user_module import UserInterface | |
from mocker_module import Mocker | |
class TestUserInterface(unittest.TestCase): | |
user_inter = Mocker(UserInterface(...)) | |
def test_create_user(self): | |
result = self.user_inter.create_user(name='russellluo') | |
self.assertEqual(result, 0) | |
``` | |
1) First, run the test once, by interacting with the real gRPC server | |
2) Then, you will find that the class `Mocker` in the module file | |
`mocker_module` is changed magically | |
3) Afterwards, you can run the test alone as many times as you wish, | |
and the test will always pass without any interaction with the real | |
gRPC server | |
Internals: | |
When interacting with the real gRPC server, the mocking class | |
`Mocker` can automatically record the real input/output data of | |
each call of each RPC function, and finally `Mocker` will change | |
itself to be a complete replacement for the real gRPC interface. | |
""" | |
import cPickle | |
import functools | |
import os | |
import pprint | |
from collections import defaultdict | |
def record(data, method):
    """Wrap `method` so that each distinct call is made once and then replayed.

    :param data: mapping of method name -> {pickled params: pickled result};
        it is updated in place. Both a `defaultdict(dict)` and a plain dict
        are accepted.
    :param method: the callable to wrap; its `__name__` selects the
        per-method sub-mapping in `data`.
    :return: a wrapper that returns the recorded result when the same
        arguments were seen before, and otherwise calls `method` and
        records its result.
    """
    # Resolve the pickle implementation here so the wrapper does not depend
    # on a module-level name that only exists on Python 2.
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    @functools.wraps(method)
    def decorator(*args, **kwargs):
        # The lookup key is the pickled call signature.
        # NOTE(review): the key depends on the pickled form of `kwargs`, so
        # it may differ for equal keyword arguments passed in another order.
        params = pickle.dumps(args) + pickle.dumps(kwargs)
        # `setdefault` rather than indexing: persisted fake data is a plain
        # dict, so a not-yet-recorded method must not raise KeyError.
        method_data = data.setdefault(method.__name__, {})
        if params not in method_data:
            result = method(*args, **kwargs)
            method_data[params] = pickle.dumps(result)
        # Always return a fresh copy unpickled from the recorded bytes.
        return pickle.loads(method_data[params])
    return decorator
class Mocker(object):
    """Record-and-replay wrapper around an interface object.

    Every public callable of the wrapped object is exposed on the mocker
    through `record`. When a mocker instance is finalized, the recorded
    data is appended to this module's own source file as a `_fake_data`
    class attribute, which later instances use as their starting data.
    """

    def __init__(self, target):
        """Wrap the public callables of `target`.

        :param target: the object whose methods are recorded/replayed.
        """
        # Start from the persisted `_fake_data` class attribute when this
        # module has already been rewritten, otherwise from an empty store.
        self._data = getattr(self, '_fake_data', defaultdict(dict))
        for attr_name in dir(target):
            if attr_name.startswith('_'):
                continue
            target_attr = getattr(target, attr_name)
            # Plain data attributes have no `__name__` and cannot be
            # recorded as calls, so only callables are wrapped.
            if not callable(target_attr):
                continue
            setattr(self, attr_name, record(self._data, target_attr))

    @staticmethod
    def _render_source(content, data):
        """Return `content` with `data` appended as `_fake_data`, or None.

        Everything after the last marker comment in `content` is replaced
        by a freshly formatted `_fake_data` class attribute. None is
        returned when `content` has no marker, so the caller can leave the
        file untouched instead of wiping it.
        """
        marker = '# Generated fake data'
        head, found, _ = content.rpartition(marker)
        if not found:
            return None
        return head + found + '\n    _fake_data = ' + pprint.pformat(dict(data))

    def __del__(self):
        """Persist the recorded data into this module's source file."""
        data = getattr(self, '_data', None)
        if data is None:
            # `__init__` did not get far enough to create the store.
            return
        mocker_file = os.path.abspath(__file__)
        # `__file__` may name the compiled file; always rewrite the source.
        if mocker_file.endswith(('.pyc', '.pyo')):
            mocker_file = mocker_file[:-1]
        with open(mocker_file, 'r') as f:
            content = f.read()
        # Build the complete new content before opening for writing, so a
        # failure while formatting cannot leave the source file truncated.
        new_content = self._render_source(content, data)
        if new_content is None:
            return
        with open(mocker_file, 'w') as f:
            f.write(new_content)
    # Generated fake data
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment