Last active
January 26, 2018 07:57
-
-
Save BYCHEN/e1f54f6debf47e8b58ebf5172f8db910 to your computer and use it in GitHub Desktop.
Support retries and client-side load balancing for gRPC.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""gRPC load balance channel module. | |
Reference by : https://github.com/justdoit0823/grpc-resolver/blob/master/grpcresolver/channel.py | |
grpcio version : 1.8.3 | |
""" | |
import random | |
import grpc | |
from grpc import _common | |
from grpc._cython import cygrpc | |
from grpc._channel import (_UNARY_UNARY_INITIAL_DUE, _EMPTY_FLAGS) | |
from grpc._channel import ( | |
_end_unary_response_blocking, | |
_channel_managed_call_management, | |
_handle_event, | |
_check_call_error, | |
_options, | |
_ChannelCallState, | |
_ChannelConnectivityState, | |
_start_unary_request, | |
_RPCState, | |
_UnaryUnaryMultiCallable) | |
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary callable that spreads RPCs over an ``LBChannels`` pool.

    Every attempt takes the next channel from ``LBChannels.select_channel()``.
    The blocking entry points (``__call__`` and ``with_call``) fail over to the
    next channel whenever an attempt finishes with a non-OK status. They stop
    once ``LBChannels.isRetry`` refuses another attempt or the caller's
    ``timeout`` has been used up. ``future`` makes a single attempt.

    The body mirrors ``grpc._channel._UnaryUnaryMultiCallable`` and relies on
    private ``grpc._channel`` helpers. It is therefore tied to the grpcio
    version named in the module docstring (1.8.3).
    """

    def __init__(self, lbchannel, method, request_serializer,
                 response_deserializer):
        """Bind the callable to *lbchannel* and the encoded *method* name."""
        self._lbchannel = lbchannel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def _prepare(self, request, timeout, metadata):
        """Serialize *request* and build the unary-unary operation batch.

        Returns:
            ``(state, operations, deadline, deadline_timespec, None)`` on
            success, or ``(None, None, None, None, rendezvous)`` when
            ``_start_unary_request`` produced no serialized request
            (presumably a serialization failure).
        """
        deadline, deadline_timespec, serialized_request, rendezvous = (
            _start_unary_request(request, timeout, self._request_serializer))
        if serialized_request is None:
            return None, None, None, None, rendezvous
        state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
        operations = (
            cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
            cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        )
        return state, operations, deadline, deadline_timespec, None

    def _blocking(self, request, timeout, metadata, credentials):
        """Make one blocking attempt on the next channel of the pool.

        Returns:
            ``(state, call, deadline)`` for the completed attempt.

        Raises:
            The rendezvous returned by ``_prepare`` when the request could
            not be prepared.
        """
        state, operations, deadline, deadline_timespec, rendezvous = (
            self._prepare(request, timeout, metadata))
        if rendezvous:
            raise rendezvous
        call = self.call_method(state, operations, metadata, credentials,
                                deadline_timespec)
        return state, call, deadline

    def _blocking_with_retry(self, request, timeout, metadata, credentials):
        """Run blocking attempts across the pool until one returns OK.

        Attempts continue while ``LBChannels.isRetry`` allows them.
        *timeout* is treated as a budget for the whole call rather than per
        attempt: each retry only gets the time that is left. No retry is
        started once the budget is spent.

        NOTE: every non-OK status is retried, including statuses such as
        INVALID_ARGUMENT for which another server cannot help.

        Returns:
            ``(state, call, deadline)`` of the last attempt made.
        """
        import time

        overall_deadline = None if timeout is None else time.time() + timeout
        attempts = 0
        while True:
            if overall_deadline is None:
                remaining = None
            else:
                remaining = max(overall_deadline - time.time(), 0)
            state, call, deadline = self._blocking(
                request, remaining, metadata, credentials)
            attempts += 1
            if state.code == grpc.StatusCode.OK:
                break
            if not self._lbchannel.isRetry(attempts):
                break
            if overall_deadline is not None and time.time() >= overall_deadline:
                break
        return state, call, deadline

    def call_method(self, state, operations, metadata, credentials,
                    deadline_timespec):
        """Run *operations* on the next channel and block for the result.

        The outcome is written into *state* by ``_handle_event``. The
        created call object is returned.
        """
        import logging

        completion_queue = cygrpc.CompletionQueue()
        channel = self._lbchannel.select_channel()
        # Library code should not write to stdout; log at debug level instead.
        logging.getLogger(__name__).debug("query : %s", channel.target)
        call = channel._channel.create_call(None, 0, completion_queue,
                                            self._method, None,
                                            deadline_timespec)
        if credentials is not None:
            call.set_credentials(credentials._credentials)
        call_error = call.start_client_batch(operations, None)
        _check_call_error(call_error, metadata)
        _handle_event(completion_queue.poll(), state,
                      self._response_deserializer)
        return call

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        """Invoke the RPC with failover and return the response message."""
        state, call, deadline = self._blocking_with_retry(
            request, timeout, metadata, credentials)
        return _end_unary_response_blocking(state, call, False, deadline)

    def with_call(self, request, timeout=None, metadata=None, credentials=None):
        """Invoke the RPC with failover and return ``(response, call)``.

        This uses the same retry policy as ``__call__``.
        """
        state, call, deadline = self._blocking_with_retry(
            request, timeout, metadata, credentials)
        return _end_unary_response_blocking(state, call, True, deadline)

    def future(self, request, timeout=None, metadata=None, credentials=None):
        """Start the RPC asynchronously on the next channel (no retry).

        Returns:
            A ``_Rendezvous`` for the in-flight RPC. If preparing the request
            failed, the rendezvous from ``_prepare`` is returned instead.
        """
        # These private helpers are not among the module-level imports, so
        # bring them into scope here.
        from grpc._channel import (_Rendezvous, _call_error_set_RPCstate,
                                   _event_handler)

        state, operations, deadline, deadline_timespec, rendezvous = (
            self._prepare(request, timeout, metadata))
        if rendezvous:
            return rendezvous
        channel = self._lbchannel.select_channel()
        call, drive_call = channel._managed_call(None, 0, self._method, None,
                                                 deadline_timespec)
        if credentials is not None:
            call.set_credentials(credentials._credentials)
        event_handler = _event_handler(state, call,
                                       self._response_deserializer)
        with state.condition:
            # Pass the operation tuple as-is, exactly like ``call_method``
            # does; the per-operation objects built in ``_prepare`` are not
            # wrapped in a cygrpc.Operations container.
            call_error = call.start_client_batch(operations, event_handler)
            if call_error != cygrpc.CallError.ok:
                _call_error_set_RPCstate(state, call_error, metadata)
                return _Rendezvous(state, None, None, deadline)
            drive_call()
        return _Rendezvous(state, call, self._response_deserializer, deadline)
class Channel(object):
    """A single gRPC connection to one target, as pooled by ``LBChannels``.

    Holds the raw ``cygrpc.Channel`` together with the managed-call factory
    (used by ``_UnaryUnaryMultiCallable.future``) and the connectivity-state
    helper that ``grpc._channel`` builds for its own channel type.
    """

    def __init__(self, target, options=None, credentials=None):
        """Open a channel to *target*.

        Args:
            target: address string, e.g. ``"localhost:50051"``.
            options: optional channel options handed to
                ``grpc._channel._options``. Presumably ``(key, value)``
                pairs; confirm against callers. ``None`` means no options.
            credentials: passed unchanged to ``cygrpc.Channel``.
                NOTE(review): presumably low-level cygrpc credentials or
                ``None``; confirm against callers.
        """
        options = () if options is None else options
        self.target = target
        self._channel = channel = cygrpc.Channel(
            _common.encode(target),
            _common.channel_args(_options(options)),
            credentials)
        self._managed_call = _channel_managed_call_management(
            _ChannelCallState(channel))
        self._connectivity_state = _ChannelConnectivityState(channel)

    def __repr__(self):
        return "{}(target={!r})".format(type(self).__name__, self.target)
class LBChannels(grpc.Channel):
    """A ``grpc.Channel`` that load-balances unary-unary RPCs over many targets.

    *target* is a comma-separated address list. One ``Channel`` is opened per
    non-blank entry. Calls are handed out round-robin, starting from a random
    index. ``isRetry`` caps the blocking attempts of a single call at the
    number of channels.

    Only ``unary_unary`` is supported. The remaining ``grpc.Channel`` methods
    raise ``NotImplementedError``.
    """

    def __init__(self, target, credentials=None, options=None):
        """Open one ``Channel`` per address listed in *target*.

        Args:
            target: comma-separated addresses, e.g. ``"h1:50051, h2:50051"``.
            credentials: forwarded unchanged to every ``Channel``.
                NOTE(review): it ends up as the credentials argument of
                ``cygrpc.Channel``. It presumably must be a low-level cygrpc
                credentials object (or None); confirm with callers.
            options: forwarded unchanged to every ``Channel``.

        Raises:
            ValueError: if *target* lists no address.
        """
        self.credentials = credentials
        self.options = options
        self._channels = self.__parserChannels(target)
        # Random starting point, so that many clients built from the same
        # list do not all send their first RPC to the same host.
        self._cur_index = random.randint(0, self.length() - 1)

    def __parserChannels(self, target):
        """Build one ``Channel`` per non-blank, comma-separated entry of *target*.

        Whitespace around each entry is stripped, so ``"a:1, b:2"`` yields
        the targets ``"a:1"`` and ``"b:2"`` rather than ``" b:2"``.

        Raises:
            ValueError: if *target* contains no usable address.
        """
        import logging

        addrs = [addr.strip() for addr in target.split(",") if addr.strip()]
        if not addrs:
            raise ValueError('No channel.')
        logger = logging.getLogger(__name__)
        channels = []
        for addr in addrs:
            logger.debug("channel target: %s", addr)
            channels.append(Channel(addr, self.options, self.credentials))
        return channels

    def select_channel(self):
        """Return the next ``Channel`` in round-robin order.

        Raises:
            ValueError: if the pool is empty.
        """
        # NOTE(review): the read-modify-write of _cur_index is not locked,
        # so concurrent callers may occasionally pick the same channel.
        channels = self._channels
        count = len(channels)
        if count == 0:
            raise ValueError('No channel.')
        chosen = channels[self._cur_index % count]
        self._cur_index = (self._cur_index + 1) % count
        return chosen

    def isRetry(self, cur_count):
        """Return True if attempt number *cur_count* (0-based) may be made.

        At most one attempt per channel is allowed. ``<`` is used instead of
        ``!=`` so that a count overshooting the pool size cannot loop forever.
        """
        return cur_count < self.length()

    def length(self):
        """Return the number of channels in the pool."""
        return len(self._channels)

    def unary_unary(self, method, request_serializer=None,
                    response_deserializer=None):
        """Return a load-balancing, retrying callable for unary-unary *method*."""
        return _UnaryUnaryMultiCallable(self, _common.encode(method),
                                        request_serializer,
                                        response_deserializer)

    def subscribe(self, callback, try_to_connect=None):
        raise NotImplementedError

    def unsubscribe(self, callback):
        raise NotImplementedError

    def unary_stream(self, method, request_serializer=None,
                     response_deserializer=None):
        raise NotImplementedError

    def stream_unary(self, method, request_serializer=None,
                     response_deserializer=None):
        raise NotImplementedError

    def stream_stream(self, method, request_serializer=None,
                      response_deserializer=None):
        raise NotImplementedError
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment