Created
May 11, 2012 08:14
-
-
Save spin6lock/2658300 to your computer and use it in GitHub Desktop.
zmq rpc demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: UTF-8 -*- | |
import sys, os | |
import time | |
import itertools | |
from weakref import KeyedRef, ref, WeakValueDictionary | |
import traceback | |
import functools | |
import zmq | |
from zmq.green import Context | |
import gevent | |
from pickle import loads as pickle_loads, dumps as pickle_dumps, HIGHEST_PROTOCOL | |
#from cPickle import loads as pickle_loads, dumps as pickle_dumps | |
#import msgpack#cross programs | |
from msgpack import loads, dumps | |
from gevent import GreenletExit, spawn, spawn_later, Timeout, sleep | |
from gevent.pool import Pool | |
from gevent.queue import Queue | |
from gevent.event import AsyncResult | |
from logging import root as logger | |
# Pickle protocol used for every pickled payload (DT_PICKLE arguments and
# results, and exceptions sent back to callers).
# NOTE(review): HIGHEST_PROTOCOL depends on the running interpreter, so peers
# on different Python versions may not be able to read each other's pickles.
PICKLE_PROTOCOL = HIGHEST_PROTOCOL
try:
    _RpcErrorBase = StandardError  # Python 2: keep the original base class
except NameError:  # Python 3 has no StandardError
    _RpcErrorBase = Exception


class RpcError(_RpcErrorBase):
    """Base class for every error defined by this RPC layer."""


class RpcFuncNoFound(RpcError):
    """The target export has no callable with the requested name."""


class RpcExportNoFound(RpcError):
    """No export is registered under the requested id."""


class RpcCallError(RpcError):
    """Call failure (not raised by the code visible in this module)."""


class RpcRuntimeError(RpcError):
    """Generic runtime failure inside the RPC layer."""


class RpcTimeout(RpcError):
    """A call received no reply within its timeout."""


class UnrecoverableError(RpcError):
    """Unrecoverable RPC failure (not raised by the visible code)."""


class RPCCloseError(RpcError):
    """Close-related RPC failure (not raised by the visible code)."""
class FakeTimeout:
    """No-op context manager, presumably a stand-in for gevent's Timeout.

    Entering binds nothing, and exceptions raised inside the ``with`` block
    propagate unchanged.
    """

    def __enter__(self):
        return None

    def __exit__(self, typ, value, tb):
        # A falsy return value means "do not suppress the exception".
        return None


# Shared module-level instances (not referenced by the visible code).
fake_timeout = FakeTimeout()
timeout_error = RpcTimeout()
_context = None


def get_context():
    """Return the module-wide zmq.green Context, creating it on first use.

    The context is built as ``Context(3)`` (pyzmq's first argument is
    io_threads) and cached in the module global ``_context``.
    """
    global _context
    if _context is not None:
        return _context
    _context = Context(3)
    return _context
# Message flags.  The low bits (RT_*) carry the record type; the higher bits
# are option flags OR-ed on top of it.
RT_REQUEST = 0x001
RT_RESPONSE = 0x002
RT_HEARTBEAT = 0x004
RT_EXCEPTION = 0x008
ST_NO_RESULT = 0x020  # request expects no reply
ST_NO_MSG = 0x040
DT_PICKLE = 0x080  # payload is pickled; msgpack is the default
DT_ZIP = 0x100  # reserved: no compression handling in the visible code
DT_PROXY = 0x200  # first argument is an object to be turned into a proxy
# Mask selecting the record-type bits (every bit below ST_NO_RESULT).
RT_MARK = ST_NO_RESULT - 1
RT_RESULTS = (RT_RESPONSE, RT_EXCEPTION)
HEARTBEAT_TIME = 30  # heartbeat interval in seconds
def log_except():
    """Print the exception currently being handled, with its full call chain.

    The output has two parts:

    - the frames *above* the handler, i.e. the callers that led to the try
      block;
    - after a ``****** Traceback ******`` separator, the usual traceback
      frames and the exception line.

    Returns the printed text, or None (printing nothing) when no exception
    is being handled.
    """
    exc_type, exc_value, tb = sys.exc_info()
    if tb is None:
        # Called outside an except block: there is nothing to report, and
        # tb.tb_frame below would raise AttributeError.
        return None
    lines = ['Traceback (most recent call last):\n']
    lines.extend(traceback.format_list(
        traceback.extract_stack(f=tb.tb_frame.f_back)))
    lines.append(' ****** Traceback ****** \n')
    lines.extend(traceback.format_list(traceback.extract_tb(tb)))
    lines.extend(traceback.format_exception_only(exc_type, exc_value))
    exc = ''.join(lines)
    print(exc)
    return exc
class AbsExport(object):
    """Base class for objects exported over RPC.

    Subclasses override ``_rpc_name_``.  RpcBase.register() indexes exports
    by this name, so it should be unique within the process.
    """
    _rpc_name_ = ''
class RpcBase(object):
    """Shared plumbing for an RPC endpoint bound to one zmq socket.

    Subclasses set ``zmq_type`` and override ``_handle(parts)``.

    Exports are held only weakly:

    - ``_exports`` maps ``id(export)`` to a KeyedRef;
    - ``_names`` maps ``export._rpc_name_`` to that id.

    Both entries are dropped when the export is unregistered or
    garbage-collected.
    """
    zmq_type = None

    def __init__(self, context=None, size=None):
        """
        :param context: zmq context for the socket; defaults to get_context().
        :param size: passed to gevent ``Pool(size=...)``, which runs the
            message handlers.
        """
        if context is None:
            context = get_context()
        self._socket = context.socket(self.zmq_type)
        self._pool = Pool(size=size)
        self._exports = {}
        self._names = {}
        self.disable = False
        self._remove = self._make_remover(self)

    @staticmethod
    def _make_remover(owner):
        """Return the KeyedRef callback that forgets a garbage-collected export.

        Only a weak reference to ``owner`` is captured, so registered exports
        do not keep the endpoint alive.
        """
        owner_ref = ref(owner)

        def remove(wr):
            obj = owner_ref()
            if obj is not None:
                obj._forget(*wr.key)
        return remove

    def _forget(self, export_id, name):
        """Drop an export's entry, and its name if that still maps to it."""
        self._exports.pop(export_id, None)
        # The name may since have been re-registered by a different export;
        # removing it unconditionally would break that export's lookup.
        if self._names.get(name) == export_id:
            del self._names[name]

    def _resolve_endpoint(self, endpoints):
        """Yield each endpoint; a single endpoint may be passed bare."""
        if isinstance(endpoints, (tuple, list)):
            for p in endpoints:
                yield p
        else:
            yield endpoints

    def bind(self, endpoints):
        """Bind the socket to one endpoint or a list/tuple of endpoints."""
        for endpoint in self._resolve_endpoint(endpoints):
            self._socket.bind(endpoint)

    def connect(self, endpoints):
        """Connect the socket to one endpoint or a list/tuple of endpoints."""
        for endpoint in self._resolve_endpoint(endpoints):
            self._socket.connect(endpoint)

    @property
    def stoped(self):
        """True before start() has been called and again after stop()."""
        return getattr(self, '_send_task', None) is None

    def start(self):
        """Spawn the sender and receiver greenlets (no-op if already running)."""
        if not self.stoped:
            return
        self.disable = False
        self._send_queue = Queue(maxsize=0)
        self._send_task = spawn(self._sender)
        self._recv_task = spawn(self._recver)

    def before_stop(self):
        """Hook run by stop() before the in-flight handlers are joined."""

    def stop(self):
        """Stop the endpoint.

        Runs before_stop(), waits for the pooled handlers, then closes the
        socket and kills the I/O greenlets.

        NOTE(review): the socket is closed here, so calling start() again
        after stop() would reuse a closed socket.
        """
        if self.stoped:
            return
        send_task = self._send_task
        self._send_task = None
        self.before_stop()
        self.disable = True
        try:
            self._pool.join(raise_error=True)
        finally:
            self._socket.close()
            send_task.kill(block=False)
            self._recv_task.kill(block=False)
            self._recv_task = None

    def _sender(self):
        """Drain _send_queue onto the socket.

        send() currently writes to the socket directly, so nothing is put on
        this queue by the visible code.  If GreenletExit arrives in the middle
        of a multipart message, the remaining frames are still sent so the
        peer never sees a truncated message.
        """
        running = True
        _send = self._socket.send
        for parts in self._send_queue:
            for i in range(len(parts) - 1):
                try:
                    _send(parts[i], flags=zmq.SNDMORE)
                except GreenletExit:
                    if i == 0:
                        return
                    running = False
                    _send(parts[i], flags=zmq.SNDMORE)
            _send(parts[-1])
            if not running:
                return

    def _recver(self):
        """Read multipart messages and hand each one to _handle() via the pool.

        If GreenletExit arrives mid-message, the rest of that message is
        read, and the loop then exits without dispatching it.
        """
        running = True
        _recv = self._socket.recv
        while True:
            parts = []
            while True:
                try:
                    part = _recv()
                except GreenletExit:
                    running = False
                    if not parts:
                        return
                    part = _recv()
                parts.append(part)
                if not self._socket.getsockopt(zmq.RCVMORE):
                    break
            if not running:
                break
            self._pool.spawn(self._handle, parts)

    def send(self, parts):
        """Send one multipart message immediately (bypassing _send_queue)."""
        self._socket.send_multipart(parts)

    def _handle(self, parts):
        """Handle a single incoming message; overridden by subclasses."""

    def proxy(self, obj):
        """Export ``obj`` so the peer can reference it; return its export id."""
        return self.register(obj)

    def register(self, export):
        """Register ``export`` under id(export) and export._rpc_name_.

        Only a weak reference is kept.  Returns the export id.
        """
        export_id = id(export)
        name = export._rpc_name_
        self._exports[export_id] = KeyedRef(export, self._remove,
                                            (export_id, name))
        self._names[name] = export_id
        return export_id

    def unregister(self, export):
        """Remove ``export``.

        Its name is kept if it now belongs to a different export.
        """
        self._forget(id(export), export._rpc_name_)
class RpcService(object):
    """RPC session with one peer, multiplexed over the owner's socket.

    An RpcServer keeps one service per peer identity; an RpcClient owns a
    single service with identity None.  The service:

    - encodes outgoing calls and matches replies to waiting callers by
      call index;
    - runs incoming requests against the exports registered on ``svr``;
    - exchanges heartbeats with the peer.
    """

    def __init__(self, svr, identity):
        """
        :param svr: owning RpcServer/RpcClient.  It provides send(),
            register(), svc_stop(), ``stoped`` and the export tables.
        :param identity: routing identity of the peer, or None when no
            routing frame is needed (client side).
        """
        self.svr = svr
        self.identity = identity
        # Call index used to match replies to requests.  count() keeps O(1)
        # state, whereas itertools.cycle() stored every index it yielded.
        self.iter_id = itertools.count()
        self._resps = {}  # call index -> AsyncResult of a waiting caller
        self._proxys = WeakValueDictionary()  # peer export id -> RpcProxy
        self.stoped = False
        self.heart_timeout = False  # True: peer is gone, don't notify it on stop
        self._heart_time = time.time()  # arrival time of the last heartbeat
        self._heart_task = spawn(self.heartbeat)

    def _stop_proxys(self):
        """Forget all cached proxies and run their close callbacks."""
        # Snapshot first: values() is a lazy view on Python 3 and would be
        # empty after clear().
        proxys = list(self._proxys.values())
        self._proxys.clear()
        for p in proxys:
            p.on_close()

    def remote_stop(self):
        """Remote call: the peer is shutting down, so stop without calling back."""
        self.heart_timeout = True
        spawn(self.stop)

    def stop(self):
        """Stop the service, notify the peer if it is reachable, detach from svr."""
        if self.stoped:
            return
        self.stoped = True
        self._stop_proxys()
        try:
            if not self.heart_timeout:
                try:
                    self.call(0, 'remote_stop', tuple(), None, timeout=20)
                except RpcTimeout:
                    # The peer did not acknowledge; shut down locally anyway.
                    pass
        finally:
            # stop() may run inside the heartbeat greenlet itself.  Killing
            # the current greenlet would interrupt svc_stop() below, and the
            # heartbeat loop ends on its own once stop() returns.
            if self._heart_task is not gevent.getcurrent():
                self._heart_task.kill(block=False)
            self.svr.svc_stop(self)

    def send(self, *args):
        """msgpack-encode ``args`` as one frame and send it to the peer.

        The frame is prefixed with the peer identity when routing is needed.
        """
        payload = dumps(args)
        if self.identity:
            self.svr.send((self.identity, payload))
        else:
            self.svr.send((payload, ))

    def _read_response(self, index, timeout):
        """Wait for the reply to call ``index``.

        Returns the decoded result.  Re-raises the exception sent back by
        the peer, or raises RpcTimeout when nothing arrived within
        ``timeout`` seconds.
        """
        rs = AsyncResult()
        self._resps[index] = rs
        try:
            resp = rs.wait(timeout)
        finally:
            # Unregister even if this greenlet is killed while waiting.
            self._resps.pop(index, None)
        if not rs.ready():
            raise RpcTimeout('rpc call %s got no reply within %ss'
                             % (index, timeout))
        if rs.exception is not None:
            raise rs.exception
        return resp

    def call(self, obj_id, name, args, kw, no_result=False,
             timeout=60, pickle=False, proxy=False):
        """Invoke ``name`` on the peer's export ``obj_id``.

        :param obj_id: remote export id (0 addresses the peer's RpcService).
        :param name: method name to call on that export.
        :param args: positional arguments (tuple).
        :param kw: keyword arguments dict, or None.
        :param no_result: fire and forget; return None without waiting.
        :param timeout: seconds to wait for the reply.
        :param pickle: encode the arguments with pickle instead of msgpack.
        :param proxy: args[0] is a local object.  It is registered as a
            local export and the peer receives an RpcProxy to it.
        :raises RpcTimeout: no reply arrived within ``timeout``.
        :raises: the exception raised by the remote method, if any.
        """
        dtype = RT_REQUEST
        if proxy:
            # Use a separate name: obj_id must stay the *remote* target.
            export_id = self.svr.register(args[0])
            args = (export_id, ) + tuple(args[1:])
            dtype |= DT_PROXY
        if pickle:
            dtype |= DT_PICKLE
            argkw = pickle_dumps((args, kw), PICKLE_PROTOCOL)
        else:
            argkw = dumps((args, kw))
        if no_result:
            dtype |= ST_NO_RESULT
        index = next(self.iter_id)
        self.send(dtype, obj_id, index, name, argkw)
        if no_result:
            return None
        return self._read_response(index, timeout)

    def _handle_request(self, parts):
        """Run an incoming request and send back its result or exception.

        SECURITY NOTE(review): ``name`` comes from the network and goes to
        getattr() unfiltered (export 0 is this service itself), and DT_PICKLE
        payloads are unpickled.  Only expose this endpoint to trusted peers.
        """
        dtype, obj_id, index, name, argkw = parts
        try:
            obj = self.get_export(obj_id)
            if obj is None:
                raise RpcExportNoFound(obj_id)
            # With a default, a missing method surfaces as RpcFuncNoFound
            # rather than a bare AttributeError.
            func = getattr(obj, name, None)
            if func is None:
                raise RpcFuncNoFound(name)
            if dtype & DT_PICKLE:
                args, kw = pickle_loads(argkw)  # untrusted data: see note above
            else:
                args, kw = loads(argkw)
            if dtype & DT_PROXY:
                # The first argument is a peer export id: wrap it in a proxy.
                args = (self.get_proxy(args[0]), ) + tuple(args[1:])
            rs = func(*args, **kw) if kw is not None else func(*args)
            if dtype & ST_NO_RESULT:
                return
            try:
                self.send(RT_RESPONSE, index, dumps(rs))
            except Exception:
                # The result cannot be msgpack-encoded: fall back to pickle.
                self.send(RT_RESPONSE | DT_PICKLE, index,
                          pickle_dumps(rs, PICKLE_PROTOCOL))
        except Exception as e:
            log_except()
            if dtype & ST_NO_RESULT or self.svr.stoped:
                return
            self._send_exception(index, e)

    def _send_exception(self, index, error):
        """Report ``error`` to the caller of request ``index``.

        An exception that cannot be pickled is replaced by an
        RpcRuntimeError carrying its repr, so the caller still gets an answer
        instead of waiting for its timeout.
        """
        try:
            payload = pickle_dumps(error, PICKLE_PROTOCOL)
        except Exception:
            payload = pickle_dumps(RpcRuntimeError(repr(error)),
                                   PICKLE_PROTOCOL)
        self.send(RT_EXCEPTION, index, payload)

    def _handle_response(self, parts):
        """Deliver a successful result to the caller waiting on its index."""
        dtype, index, argkw = parts
        rs = self._resps.pop(index, None)
        if rs is None:
            return  # the caller already gave up; drop the late reply
        try:
            if dtype & DT_PICKLE:
                result = pickle_loads(argkw)  # untrusted data: trusted peers only
            else:
                result = loads(argkw)
        except Exception as e:
            rs.set_exception(RpcRuntimeError('cannot decode response: %r' % (e, )))
            return
        rs.set(result)

    def _handle_exception(self, parts):
        """Deliver a remote exception to the caller waiting on its index."""
        _dtype, index, error = parts
        rs = self._resps.pop(index, None)
        if rs is None:
            return  # the caller already gave up; drop the late reply
        try:
            error = pickle_loads(error)  # untrusted data: trusted peers only
        except Exception as e:
            error = RpcRuntimeError('cannot decode remote exception: %r' % (e, ))
        rs.set_exception(error)

    def _handle(self, parts):
        """Decode one incoming message and dispatch it on its record type."""
        if len(parts) == 2:
            # tuple(): msgpack may decode arrays as lists.
            parts = (parts[0], ) + tuple(loads(parts[1]))
        else:
            parts = loads(parts[0])
        rt = parts[0] & RT_MARK
        if rt == RT_REQUEST:
            self._handle_request(parts)
        elif rt == RT_RESPONSE:
            self._handle_response(parts)
        elif rt == RT_EXCEPTION:
            self._handle_exception(parts)
        elif rt == RT_HEARTBEAT:
            self._heart_time = time.time()
        else:
            raise ValueError('unknown data:%s' % (parts, ))

    def heartbeat(self):
        """Send a heartbeat every HEARTBEAT_TIME seconds.

        The service stops itself once no heartbeat has arrived for longer
        than HEARTBEAT_TIME + 10 seconds.
        """
        beat = RT_HEARTBEAT
        btime = HEARTBEAT_TIME
        while not self.stoped:
            self.send(beat)
            sleep(btime)
            if self._heart_time + btime + 10 < time.time():
                self.heart_timeout = True
                break
        self.stop()

    # ---- export lookup (also callable by the peer through export id 0) ----
    def get_export(self, export_id):
        """Return the local export registered under ``export_id``.

        Id 0 returns this service itself.  Returns None when the id is
        unknown or the export has been garbage-collected.
        """
        if export_id == 0:
            return self
        try:
            return self.svr._exports[export_id]()
        except KeyError:
            return None

    def get_id_by_name(self, name):
        """Return the id of the local export named ``name``, or None."""
        return self.svr._names.get(name)

    def get_proxy(self, export_id):
        """Return the cached RpcProxy for the peer's export ``export_id``.

        A new proxy is created if none is cached.
        """
        proxy = self._proxys.get(export_id)
        if proxy is None:
            proxy = RpcProxy(self, export_id)
            self._proxys[export_id] = proxy
        return proxy

    def get_proxy_by_name(self, name):
        """Look up the peer's export named ``name``.

        Returns a proxy to it, or None if the peer does not know the name.
        """
        export_id = self.call(0, 'get_id_by_name', (name, ), None)
        if export_id:
            return self.get_proxy(export_id)
        return None
class RpcClient(RpcBase):
    """Client endpoint: an XREQ socket with one RpcService (identity None)."""
    zmq_type = zmq.XREQ

    def __init__(self, context=None, size=None):
        """
        :param context: zmq context; defaults to the shared one.
        :param size: handler pool size passed on to RpcBase.
        """
        # Pass the caller's arguments through (they used to be dropped).
        RpcBase.__init__(self, context=context, size=size)
        self.svc = RpcService(self, None)

    def _handle(self, parts):
        """Handle one incoming message by passing it to the single service."""
        self.svc._handle(parts)

    def before_stop(self):
        """Stop the service before the socket is closed.

        The service notifies the server unless the peer timed out.
        """
        self.svc.stop()

    def svc_stop(self, service):
        """Called by the service when it stops; stops the whole client too."""
        if service is self.svc:
            self.stop()
class RpcServer(RpcBase):
    """Server endpoint: an XREP socket with one RpcService per peer identity."""
    zmq_type = zmq.XREP

    def __init__(self, context=None, size=None):
        """
        :param context: zmq context; defaults to the shared one.
        :param size: handler pool size passed on to RpcBase.
        """
        # Pass the caller's arguments through (they used to be dropped).
        RpcBase.__init__(self, context=context, size=size)
        self._services = {}  # peer identity -> RpcService

    def before_stop(self):
        """Stop every peer service before the socket is closed."""
        # svc.stop() calls svc_stop(), which removes entries from _services,
        # so iterate over a snapshot.
        for svc in list(self._services.values()):
            svc.stop()

    def get_service(self, identity):
        """Return the service for ``identity``, creating it on first contact."""
        svc = self._services.get(identity)
        if svc is None:
            svc = RpcService(self, identity)
            self._services[identity] = svc
        return svc

    def _handle(self, parts):
        """Route one message to its service.

        parts[0] is the sender's routing identity; the remaining frames go
        to that peer's service.
        """
        self.get_service(parts[0])._handle(parts[1:])

    def svc_stop(self, service):
        """Forget a service that has stopped."""
        self._services.pop(service.identity, None)
class RpcProxy(object):
    """Local stand-in for an export living on the other side of an RpcService.

    Accessing a public attribute returns a callable that performs the remote
    call.  These keyword arguments tune the call and are not forwarded:

    - ``_no_result``: fire and forget;
    - ``_timeout``: seconds to wait, default 60;
    - ``_pickle``: pickle the arguments instead of using msgpack;
    - ``_proxy``: send the first argument as a callable proxy.
    """
    # Own attributes; never forwarded to the peer (e.g. while unpickling).
    _LOCAL_ATTRS = frozenset(('_svc', '_id', '_closes'))

    def __init__(self, svc, export_id):
        """
        :param svc: the RpcService the calls go through.
        :param export_id: id of the export on the remote side.
        """
        self._svc = svc
        self._id = export_id
        self._closes = []  # callbacks run by on_close()

    def __getattr__(self, attribute):
        # Only reached for attributes not found normally.  Dunder names are
        # protocol probes (copy, pickle, hasattr(obj, '__iter__'), ...) and
        # must not become remote calls.
        if (attribute in self._LOCAL_ATTRS or
                (attribute.startswith('__') and attribute.endswith('__'))):
            raise AttributeError(attribute)

        def _func(*args, **kw):
            no_result = kw.pop('_no_result', False)
            timeout = kw.pop('_timeout', 60)
            pickle = kw.pop('_pickle', False)
            proxy = kw.pop('_proxy', False)
            return self._svc.call(self._id, attribute, args, kw,
                                  no_result=no_result, timeout=timeout,
                                  pickle=pickle, proxy=proxy)
        return _func

    def on_close(self):
        """Call every close callback with this proxy.

        A failing callback is logged and does not prevent the others.
        """
        # Snapshot: a callback may unsubscribe itself, which would otherwise
        # make the loop skip the next callback.
        for func in list(self._closes):
            try:
                func(self)
            except Exception:
                log_except()

    def sub_close(self, func):
        """Register ``func(proxy)`` to run when the service closes (once)."""
        if func not in self._closes:
            self._closes.append(func)

    def unsub_close(self, func):
        """Remove a callback added with sub_close(); unknown ones are ignored."""
        if func in self._closes:
            self._closes.remove(func)
def test():
    """Functional demo: one server exporting ``Test`` and one client calling it.

    Exercises:

    - plain calls and args/kwargs;
    - a remote exception;
    - a callback through a passed proxy (``_proxy=True``);
    - a fire-and-forget call (``_no_result=True``).

    HEARTBEAT_TIME is lowered to 1 second for the run and restored
    afterwards.  Client and server are stopped even if a call fails.
    """
    class Test(AbsExport):
        _rpc_name_ = 'test'

        def t1(self):
            return 't1'

        def t2(self, *args, **kw):
            return (args, kw)

        def t3(self):
            raise ValueError('t3')

        def t4(self, proxy, a, b):
            # proxy is an RpcProxy for the caller's object: call back into it.
            rs = proxy.t1()
            return rs, a, b

        def tt(self):
            # Invoked with _no_result=True: the error is only logged remotely.
            raise ValueError('t4')

    global HEARTBEAT_TIME
    endpoint = 'tcp://127.0.0.1:8081'
    saved_heartbeat = HEARTBEAT_TIME
    HEARTBEAT_TIME = 1
    t1 = Test()
    svr = RpcServer()
    try:
        svr.bind(endpoint)
        svr.register(t1)
        svr.start()

        def client_test(name):
            client = RpcClient()
            client.connect(endpoint)
            client.start()
            try:
                proxy = client.svc.get_proxy_by_name(t1._rpc_name_)
                print('%s %s' % (name, proxy.t1()))
                print('%s %s' % (name, proxy.t2(1, 2, 3, a=1, b=2)))
                try:
                    proxy.t3()
                except ValueError as e:
                    print('%s %s' % (name, e))
                print('%s %s' % (name, proxy.t4(t1, 'a', 'b', _proxy=True)))
                print('%s %s' % (name, proxy.tt(_no_result=True)))
                # Let a few 1-second heartbeats go back and forth.
                gevent.sleep(3)
            finally:
                client.stop()

        clients = [spawn(client_test, 'A')]
        gevent.joinall(clients)
    finally:
        svr.stop()
        HEARTBEAT_TIME = saved_heartbeat
def benchmark(number=1000):
    """Measure synchronous echo round-trips over loopback TCP.

    :param number: how many echo calls to time (default 1000).

    Prints the total time and calls per second.  Client and server are
    stopped even if a call fails.
    """
    import timeit

    class Test(AbsExport):
        _rpc_name_ = 'test'

        def echo(self, msg):
            return msg

    endpoint = 'tcp://127.0.0.1:8081'
    t1 = Test()
    svr = RpcServer()
    try:
        svr.bind(endpoint)
        svr.register(t1)
        svr.start()
        client = RpcClient()
        try:
            client.connect(endpoint)
            client.start()
            proxy = client.svc.get_proxy_by_name(t1._rpc_name_)
            msg = '*' * 10
            index = [0]  # mutable counter updated by the timed closure

            def test_echo():
                rs = proxy.echo(msg)
                assert rs == msg, 'echo error'
                index[0] += 1

            t = timeit.timeit(test_echo, number=number)
            print('total:%s %s per/sec' % (t, number / float(t)))
            assert index[0] == number, 'index error'
        finally:
            client.stop()
    finally:
        svr.stop()
def main():
    """Script entry point: run the echo benchmark.

    test() is the functional demo and can be called here instead.
    """
    return benchmark()


if __name__ == '__main__':
    main()
#------------------------------------------------------------------------------- | |
#------------------------------------------------------------------------------- | |
#------------------------------------------------------------------------------- |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment