Skip to content

Instantly share code, notes, and snippets.

@BYCHEN
Last active January 26, 2018 07:57
Show Gist options
  • Save BYCHEN/e1f54f6debf47e8b58ebf5172f8db910 to your computer and use it in GitHub Desktop.
Save BYCHEN/e1f54f6debf47e8b58ebf5172f8db910 to your computer and use it in GitHub Desktop.
Support retry and client-side load balancing for gRPC.
"""gRPC load balance channel module.
Reference by : https://github.com/justdoit0823/grpc-resolver/blob/master/grpcresolver/channel.py
grpcio version : 1.8.3
"""
import logging
import random

import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc._channel import (_UNARY_UNARY_INITIAL_DUE, _EMPTY_FLAGS)
from grpc._channel import (
    _end_unary_response_blocking,
    _channel_managed_call_management,
    _handle_event,
    _check_call_error,
    _options,
    _ChannelCallState,
    _ChannelConnectivityState,
    _start_unary_request,
    _RPCState,
    _UnaryUnaryMultiCallable,
    _event_handler,
    _call_error_set_RPCstate,
    _Rendezvous)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary callable that dispatches requests through an ``LBChannels``.

    Every attempt runs on the channel returned by
    ``LBChannels.select_channel()``. Blocking invocations (``__call__`` and
    ``with_call``) retry on the next selected channel. They stop when an
    attempt ends with ``grpc.StatusCode.OK`` or when ``LBChannels.isRetry``
    reports that the attempt budget is used up. ``future`` makes exactly one
    attempt.

    This class shadows the imported ``grpc._channel._UnaryUnaryMultiCallable``.
    It is assembled from that module's private helpers, so it is tied to the
    grpcio version the module was written against.
    """

    def __init__(self, lbchannel, method, request_serializer,
                 response_deserializer):
        """Store the owning channel pool and the per-method (de)serializers.

        Args:
            lbchannel: the ``LBChannels`` pool that supplies channels.
            method: the RPC method name. ``LBChannels.unary_unary`` passes it
                already encoded with ``_common.encode``.
            request_serializer: handed to ``_start_unary_request``.
            response_deserializer: handed to the response event handlers.
        """
        self._lbchannel = lbchannel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def _prepare(self, request, timeout, metadata):
        """Serialize ``request`` and build the unary-unary operation batch.

        Returns:
            ``(state, operations, deadline, deadline_timespec, None)`` when a
            serialized request was produced. Otherwise returns
            ``(None, None, None, None, rendezvous)`` with the rendezvous from
            ``_start_unary_request``, presumably a serialization failure.
        """
        deadline, deadline_timespec, serialized_request, rendezvous = (
            _start_unary_request(request, timeout, self._request_serializer))
        if serialized_request is None:
            return None, None, None, None, rendezvous
        state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
        operations = (
            cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
            cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        )
        return state, operations, deadline, deadline_timespec, None

    def _blocking(self, request, timeout, metadata, credentials):
        """Run one blocking attempt on the next pooled channel.

        Returns:
            ``(state, call, deadline)`` for the attempt.

        Raises:
            The rendezvous from ``_prepare`` when no request was serialized.
        """
        state, operations, deadline, deadline_timespec, rendezvous = (
            self._prepare(request, timeout, metadata))
        if rendezvous:
            raise rendezvous
        call = self.call_method(state, operations, metadata, credentials,
                                deadline_timespec)
        return state, call, deadline

    def call_method(self, state, operations, metadata, credentials,
                    deadline_timespec):
        """Start ``operations`` on the next pooled channel and wait for them.

        ``state`` is updated in place by ``_handle_event``.

        Returns:
            The low-level call object created on the selected channel.
        """
        completion_queue = cygrpc.CompletionQueue()
        channel = self._lbchannel.select_channel()
        # Use logging instead of print so library users control the output.
        logging.getLogger(__name__).debug('query : %s', channel.target)
        call = channel._channel.create_call(
            None, 0, completion_queue, self._method, None, deadline_timespec)
        if credentials is not None:
            # Per-call credentials: unwrap the grpc-level wrapper.
            call.set_credentials(credentials._credentials)
        call_error = call.start_client_batch(operations, None)
        _check_call_error(call_error, metadata)
        # Wait for the completion event of the batch started above.
        _handle_event(completion_queue.poll(), state,
                      self._response_deserializer)
        return call

    def _blocking_with_retry(self, request, timeout, metadata, credentials):
        """Repeat ``_blocking`` until an attempt is OK or the budget is spent.

        Any non-OK status triggers another attempt on the next channel.
        NOTE(review): that includes non-transient codes, and the request is
        re-sent, so only use this with idempotent methods.
        Each attempt calls ``_prepare`` again, which presumably derives a
        fresh deadline from ``timeout``. Total wall time can therefore exceed
        ``timeout``.

        Returns:
            ``(state, call, deadline)`` of the last attempt made.
        """
        attempts = 0
        while True:
            state, call, deadline = self._blocking(
                request, timeout, metadata, credentials)
            attempts += 1
            # At least one attempt always runs, so state/call/deadline are
            # always bound when we return.
            if (state.code == grpc.StatusCode.OK
                    or not self._lbchannel.isRetry(attempts)):
                return state, call, deadline

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        """Invoke the RPC with retries and return the response message."""
        state, call, deadline = self._blocking_with_retry(
            request, timeout, metadata, credentials)
        return _end_unary_response_blocking(state, call, False, deadline)

    def with_call(self, request, timeout=None, metadata=None,
                  credentials=None):
        """Invoke the RPC with retries (same policy as ``__call__``).

        Returns:
            Whatever ``_end_unary_response_blocking`` returns with
            ``with_call=True``.
        """
        state, call, deadline = self._blocking_with_retry(
            request, timeout, metadata, credentials)
        return _end_unary_response_blocking(state, call, True, deadline)

    def future(self, request, timeout=None, metadata=None, credentials=None):
        """Start a single non-blocking attempt and return its rendezvous."""
        state, operations, deadline, deadline_timespec, rendezvous = (
            self._prepare(request, timeout, metadata))
        if rendezvous:
            return rendezvous
        channel = self._lbchannel.select_channel()
        call, drive_call = channel._managed_call(
            None, 0, self._method, None, deadline_timespec)
        if credentials is not None:
            call.set_credentials(credentials._credentials)
        event_handler = _event_handler(state, call,
                                       self._response_deserializer)
        with state.condition:
            # Pass the operation tuple directly, exactly as call_method does.
            call_error = call.start_client_batch(operations, event_handler)
            if call_error != cygrpc.CallError.ok:
                _call_error_set_RPCstate(state, call_error, metadata)
                return _Rendezvous(state, None, None, deadline)
            drive_call()
        return _Rendezvous(state, call, self._response_deserializer, deadline)
class Channel(object):
    """One pooled connection used by ``LBChannels``.

    Holds a low-level cygrpc channel to a single target, plus the
    ``grpc._channel`` call-management and connectivity helpers built on it.
    """

    def __init__(self, target, options=None, credentials=None):
        """Open a cygrpc channel to ``target``.

        Args:
            target: a single address string.
            options: channel options. ``None`` means no options.
            credentials: ``None`` for an insecure channel, or a grpc-level
                credentials wrapper whose ``_credentials`` attribute is
                unwrapped. Any object without ``_credentials`` (e.g. already
                low-level credentials) is passed through unchanged.
        """
        options = () if options is None else options
        # grpc credential wrappers keep the low-level object in
        # ``_credentials`` (the per-call path unwraps them the same way).
        # cygrpc.Channel needs the inner object, not the wrapper.
        raw_credentials = getattr(credentials, '_credentials', credentials)
        self.target = target
        self._channel = channel = cygrpc.Channel(
            _common.encode(target),
            _common.channel_args(_options(options)),
            raw_credentials)
        self._managed_call = _channel_managed_call_management(
            _ChannelCallState(channel))
        self._connectivity_state = _ChannelConnectivityState(channel)
class LBChannels(grpc.Channel):
    """Client-side round-robin load balancer over several gRPC targets.

    ``target`` is a comma-separated address list such as
    ``"host1:50051, host2:50051"``. One ``Channel`` is opened per non-blank
    entry. Only ``unary_unary`` is implemented; the other ``grpc.Channel``
    methods raise ``NotImplementedError``.
    """

    def __init__(self, target, credentials=None, options=None):
        """Open one channel per address listed in ``target``.

        Args:
            target: comma-separated list of addresses.
            credentials: passed unchanged to every ``Channel``.
            options: channel options passed unchanged to every ``Channel``.

        Raises:
            ValueError: if ``target`` contains no non-blank address.
        """
        self.credentials = credentials
        self.options = options
        self._channels = self._parse_channels(target)
        # Start the rotation at a random channel so independent clients do
        # not all send their first request to the first address.
        self._cur_index = random.randrange(self.length())

    def _parse_channels(self, target):
        """Build one ``Channel`` per non-blank, whitespace-stripped address.

        Raises:
            ValueError: if no usable address is found.
        """
        channels = []
        for raw_addr in target.split(','):
            # Strip before use: "a:1, b:2" must yield "b:2", not " b:2".
            addr = raw_addr.strip()
            if not addr:
                continue
            logging.getLogger(__name__).debug('channel target : %s', addr)
            channels.append(Channel(addr, self.options, self.credentials))
        if not channels:
            raise ValueError('No channel.')
        return channels

    def select_channel(self):
        """Return the next channel in round-robin order and advance the cursor.

        Raises:
            ValueError: if the pool is empty.
        """
        channels = self._channels
        count = len(channels)
        if count == 0:
            raise ValueError('No channel.')
        # NOTE(review): the cursor is read and updated without a lock; confirm
        # whether concurrent callers require strict rotation.
        selected = channels[self._cur_index % count]
        self._cur_index = (self._cur_index + 1) % count
        return selected

    def isRetry(self, cur_count):
        """Return True if another attempt is allowed after ``cur_count`` ones.

        The budget is one attempt per pooled channel. Using ``<`` instead of
        ``!=`` keeps the answer False for any count past the budget.
        """
        return cur_count < self.length()

    def length(self):
        """Return the number of pooled channels."""
        return len(self._channels)

    def unary_unary(self, method, request_serializer=None,
                    response_deserializer=None):
        """Return a load-balanced, retrying unary-unary callable for ``method``."""
        return _UnaryUnaryMultiCallable(self, _common.encode(method),
                                        request_serializer,
                                        response_deserializer)

    def subscribe(self, callback, try_to_connect=None):
        """Connectivity subscription is not supported."""
        raise NotImplementedError

    def unsubscribe(self, callback):
        """Connectivity subscription is not supported."""
        raise NotImplementedError

    def unary_stream(self, method, request_serializer=None,
                     response_deserializer=None):
        """Streaming RPCs are not supported."""
        raise NotImplementedError

    def stream_unary(self, method, request_serializer=None,
                     response_deserializer=None):
        """Streaming RPCs are not supported."""
        raise NotImplementedError

    def stream_stream(self, method, request_serializer=None,
                      response_deserializer=None):
        """Streaming RPCs are not supported."""
        raise NotImplementedError
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment