Skip to content

Instantly share code, notes, and snippets.

@spin6lock
Created May 11, 2012 08:14
Show Gist options
  • Save spin6lock/2658300 to your computer and use it in GitHub Desktop.
Save spin6lock/2658300 to your computer and use it in GitHub Desktop.
zmq rpc demo
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import sys, os
import time
import itertools
from weakref import KeyedRef, ref, WeakValueDictionary
import traceback
import functools
import zmq
from zmq.green import Context
import gevent
from pickle import loads as pickle_loads, dumps as pickle_dumps, HIGHEST_PROTOCOL
#from cPickle import loads as pickle_loads, dumps as pickle_dumps
#import msgpack#cross programs
from msgpack import loads, dumps
from gevent import GreenletExit, spawn, spawn_later, Timeout, sleep
from gevent.pool import Pool
from gevent.queue import Queue
from gevent.event import AsyncResult
from logging import root as logger
# Pickle protocol used for every payload sent with the DT_PICKLE flag:
# arguments when pickle=True, results msgpack cannot encode, and exceptions.
PICKLE_PROTOCOL = HIGHEST_PROTOCOL
class RpcError(StandardError):
pass
class RpcFuncNoFound(RpcError):
pass
class RpcExportNoFound(RpcError):
pass
class RpcCallError(RpcError):
pass
class RpcRuntimeError(RpcError):
pass
class RpcTimeout(RpcError):
pass
class UnrecoverableError(RpcError):
pass
class RPCCloseError(RpcError):
pass
class FakeTimeout:
    """Context manager that does nothing.

    Presumably a stand-in for a gevent ``Timeout`` when no timeout is
    wanted. Entering yields None, and exiting returns None, so exceptions
    raised in the ``with`` body always propagate.
    """

    def __enter__(self):
        return None

    def __exit__(self, typ, value, tb):
        return None


# Shared instance; the class holds no state, so one object suffices.
fake_timeout = FakeTimeout()
# Preallocated, module-wide RpcTimeout instance; it is not referenced
# anywhere else in the code shown.
timeout_error = RpcTimeout()
# Lazily created zmq.green context shared by every RpcBase that is built
# without an explicit context.
_context = None


def get_context():
    """Return the shared ``zmq.green`` context, creating it on first use.

    The context is created as ``Context(3)``, i.e. with 3 I/O threads.
    """
    global _context
    if _context is not None:
        return _context
    _context = Context(3)
    return _context
# Message types: the low bits of the first field of every message.
RT_REQUEST = 0x001
RT_RESPONSE = 0x002
RT_HEARTBEAT = 0x004
RT_EXCEPTION = 0x008
# Sub-type flags.
ST_NO_RESULT = 0x020  # fire-and-forget: the caller does not wait for a reply
ST_NO_MSG = 0x040
# Data-encoding flags (msgpack is the default encoding).
DT_PICKLE = 0x080  # payload is pickled instead of msgpack-encoded
DT_ZIP = 0x100
DT_PROXY = 0x200  # the first argument is an object id to be turned into a proxy
# Mask selecting the message-type bits (everything below ST_NO_RESULT).
RT_MARK = ST_NO_RESULT - 1
RT_RESULTS = (RT_RESPONSE, RT_EXCEPTION)
HEARTBEAT_TIME = 30  # heartbeat interval, in seconds
def log_except():
    """Print the exception currently being handled, with extended context.

    The printed text contains:
    - the call stack above the frame that caught the exception;
    - a ``****** Traceback ******`` separator;
    - the frames the exception travelled through;
    - the final ``ExcType: message`` line.

    Intended to be called from inside an ``except`` block. When no exception
    is being handled, nothing is printed (previously this raised
    AttributeError from inside the error handler itself).
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    if exc_tb is None:
        return
    lines = ['Traceback (most recent call last):\n']
    # Callers of the frame in which the exception was caught.
    outer = traceback.extract_stack(f=exc_tb.tb_frame.f_back)
    lines.extend(traceback.format_list(outer))
    lines.append(' ****** Traceback ****** \n')
    # From the catching frame down to where the exception was raised.
    lines.extend(traceback.format_list(traceback.extract_tb(exc_tb)))
    lines.extend(traceback.format_exception_only(exc_type, exc_value))
    print(''.join(lines))
class AbsExport(object):
    """Base class for objects published through ``RpcBase.register``.

    Subclasses override ``_rpc_name_``; peers resolve an export by this name.
    """

    # Name that must be unique within the process.
    _rpc_name_ = ''
class RpcBase(object):
    """Common machinery for one zmq socket that carries RPC traffic.

    Each instance owns:
    - the socket;
    - a gevent ``Pool`` that runs ``_handle`` once per incoming multipart
      message;
    - the registry of local objects ("exports") that peers may call.

    Subclasses set ``zmq_type`` and override ``_handle``.
    """

    zmq_type = None

    def __init__(self, context=None, size=None):
        """Create the socket and an empty export registry.

        :param context: zmq context; defaults to the shared ``get_context()``.
        :param size: maximum number of concurrent handler greenlets
            (passed to ``Pool``; None means unbounded).
        """
        if context is None:
            context = get_context()
        self._socket = context.socket(self.zmq_type)
        self._pool = Pool(size=size)
        # export id -> KeyedRef(export). The reference is weak, so
        # registering an object does not keep it alive.
        self._exports = {}
        # export name -> export id
        self._names = {}
        self.disable = False
        # Weakref callback attached to every KeyedRef created by register().
        self._remove = self._make_remove_callback()

    def _make_remove_callback(self):
        """Build the weakref callback that forgets an export once it dies.

        The callback holds only a weak reference to ``self``, so the registry
        does not create a reference cycle through it.
        """
        selfref = ref(self)

        def remove(wr):
            owner = selfref()
            if owner is not None:
                export_id, name = wr.key
                owner._forget(export_id, name, wr)
        return remove

    def _forget(self, export_id, name, wr=None):
        """Drop the registry entries of ``export_id`` / ``name``.

        An entry is removed only if it still belongs to this export:
        - a name since re-registered by another object is kept;
        - when ``wr`` is given, an id slot holding a different weak
          reference is kept.
        """
        current = self._exports.get(export_id)
        if current is not None and (wr is None or current is wr):
            del self._exports[export_id]
        if self._names.get(name) == export_id:
            del self._names[name]

    def _resolve_endpoint(self, endpoints):
        """Yield each endpoint; accepts one endpoint or a list/tuple of them."""
        if isinstance(endpoints, (tuple, list)):
            for endpoint in endpoints:
                yield endpoint
        else:
            yield endpoints

    def bind(self, endpoints):
        """Bind the socket to one or several endpoints."""
        for endpoint in self._resolve_endpoint(endpoints):
            self._socket.bind(endpoint)

    def connect(self, endpoints):
        """Connect the socket to one or several endpoints."""
        for endpoint in self._resolve_endpoint(endpoints):
            self._socket.connect(endpoint)

    @property
    def stoped(self):
        """True before start() has run, and again as soon as stop() begins."""
        return getattr(self, '_send_task', None) is None

    def start(self):
        """Spawn the sender and receiver greenlets (no-op if already running)."""
        if not self.stoped:
            return
        self.disable = False
        self._send_queue = Queue(maxsize=0)
        self._send_task = spawn(self._sender)
        self._recv_task = spawn(self._recver)

    def before_stop(self):
        """Hook run by stop() before handlers are drained; override as needed."""

    def stop(self):
        """Shut down: run before_stop(), drain handlers, close the socket.

        In-flight handler greenlets are joined with ``raise_error=True``, so
        a handler error may surface here. The socket is closed and the I/O
        greenlets are killed in any case.
        """
        if self.stoped:
            return
        send_task = self._send_task
        # Mark as stopped first so re-entrant stop() calls (e.g. triggered
        # from before_stop) return immediately.
        self._send_task = None
        self.before_stop()
        self.disable = True
        try:
            self._pool.join(raise_error=True)
        finally:
            self._socket.close()
            send_task.kill(block=False)
            self._recv_task.kill(block=False)
            self._recv_task = None

    def _sender(self):
        """Drain ``_send_queue``, writing each queued list of frames.

        If a kill (GreenletExit) interrupts a non-final frame other than the
        first, the frame is re-sent and the message is finished before the
        greenlet exits. A kill during the first frame exits at once.

        Note: ``send()`` currently writes directly and bypasses this queue.
        """
        running = True
        _send = self._socket.send
        for parts in self._send_queue:
            for i, part in enumerate(parts[:-1]):
                try:
                    _send(part, flags=zmq.SNDMORE)
                except GreenletExit:
                    if i == 0:
                        return
                    running = False
                    _send(part, flags=zmq.SNDMORE)
            self._socket.send(parts[-1])
            if not running:
                return

    def _recver(self):
        """Read multipart messages and run ``_handle`` for each in the pool.

        A kill that arrives mid-message is deferred until the remaining
        frames are read; the loop then exits without dispatching them.
        """
        running = True
        _recv = self._socket.recv
        while True:
            parts = []
            while True:
                try:
                    part = _recv()
                except GreenletExit:
                    running = False
                    if not parts:
                        return
                    part = _recv()
                parts.append(part)
                if not self._socket.getsockopt(zmq.RCVMORE):
                    break
            if not running:
                break
            self._pool.spawn(self._handle, parts)

    def send(self, parts):
        """Write ``parts`` to the socket as one multipart message."""
        self._socket.send_multipart(parts)
        # Queued alternative, drained by _sender():
        # self._send_queue.put(parts)

    def _handle(self, parts):
        """Process one received multipart message; subclasses override this."""

    def proxy(self, obj):
        """Register ``obj`` so a peer can reference it, and return its export id.

        The id is what gets sent to the peer, which wraps it in a proxy.
        """
        return self.register(obj)

    def register(self, export):
        """Publish ``export`` under ``id(export)`` and its ``_rpc_name_``.

        Only a weak reference is stored. Once the object is garbage
        collected, its entries are dropped automatically.

        :returns: the export id.
        """
        export_id = id(export)
        name = export._rpc_name_
        self._exports[export_id] = KeyedRef(export, self._remove, (export_id, name))
        self._names[name] = export_id
        return export_id

    def unregister(self, export):
        """Withdraw ``export``.

        The export's name is kept if another object has since been
        registered under it.
        """
        self._forget(id(export), export._rpc_name_)
class RpcService(object):
    """RPC state for a single peer: one per client, one per peer on a server.

    Responsibilities:
    - numbers outgoing requests and matches replies to waiting callers;
    - dispatches incoming requests to local exports;
    - caches proxies for remote exports;
    - exchanges heartbeats with the peer.
    """

    def __init__(self, svr, identity):
        """Set up per-peer state and start the heartbeat greenlet.

        :param svr: owning RpcClient/RpcServer. Used for send(), register(),
            the export registry and svc_stop().
        :param identity: routing identity of the peer (the first frame on a
            server socket), or None on the client side.
        """
        self.svr = svr
        self.identity = identity
        # Request ids. itertools.cycle(xrange(sys.maxint)) saved a copy of
        # every id it produced, so memory grew with each call; count() keeps
        # constant state, and ids never repeat in practice.
        self.iter_id = itertools.count()
        # request id -> AsyncResult awaiting the peer's reply
        self._resps = {}
        # remote export id -> RpcProxy, weakly held
        self._proxys = WeakValueDictionary()
        self.stoped = False
        self.heart_timeout = False
        self._heart_time = time.time()
        self._heart_task = spawn(self.heartbeat)

    def _stop_proxys(self):
        """Forget every known proxy and run its on_close callbacks."""
        if not self._proxys:
            return
        # Take strong references before clearing, so the proxies stay alive
        # while their callbacks run.
        proxys = list(self._proxys.values())
        self._proxys.clear()
        for p in proxys:
            p.on_close()

    def remote_stop(self):
        """Called by a peer that is shutting down.

        Marks the link as dead, so stop() does not call back to the peer,
        and runs stop() in a separate greenlet.
        """
        self.heart_timeout = True
        spawn(self.stop)

    def stop(self):
        """Stop the service once: close proxies, notify the peer, deregister.

        Unless the link is already known to be dead, the peer is asked to
        stop via 'remote_stop' (20 s timeout). Killing the heartbeat task and
        deregistering from the owner happen even if that notification fails.
        """
        if self.stoped:
            return
        self.stoped = True
        self._stop_proxys()
        try:
            if not self.heart_timeout:
                self.call(0, 'remote_stop', tuple(), None, timeout=20)
        except Exception:
            # Best effort: an unreachable or failing peer must not prevent
            # local cleanup.
            log_except()
        finally:
            self._heart_task.kill(block=False)
            self.svr.svc_stop(self)

    def send(self, *args):
        """msgpack-encode ``args`` and send it to the peer.

        When an identity is set, it is sent as the first frame so the
        server socket can route the message.
        """
        payload = dumps(args)
        if self.identity:
            self.svr.send((self.identity, payload))
        else:
            self.svr.send((payload, ))

    def _read_response(self, index, timeout, rs=None):
        """Wait up to ``timeout`` seconds for the reply to request ``index``.

        :param rs: AsyncResult already registered in ``_resps``. When None,
            a new one is created and registered here.
        :returns: the decoded result, or None if the timeout expires.
        :raises: the exception the peer reported for this request.
        """
        if rs is None:
            rs = AsyncResult()
            self._resps[index] = rs
        try:
            resp = rs.wait(timeout)
        finally:
            # Release the slot even if this greenlet is killed while waiting.
            # Any late reply for it is then ignored.
            self._resps.pop(index, None)
        if rs.exception is not None:
            raise rs.exception
        return resp

    def call(self, obj_id, name, args, kw, no_result=False,
             timeout=60, pickle=False, proxy=False):
        """Invoke ``name(*args, **kw)`` on remote export ``obj_id``.

        :param obj_id: remote export id; 0 addresses the peer's RpcService.
        :param no_result: fire-and-forget; return None without waiting.
        :param timeout: seconds to wait for the reply (None on expiry).
        :param pickle: encode the arguments with pickle instead of msgpack.
        :param proxy: ``args[0]`` is a local object. It is registered locally
            and sent as its export id, which the peer wraps in a proxy.
        :raises: the peer-side exception, re-raised locally.
        """
        dtype = RT_REQUEST
        if proxy:
            # Use a separate name: overwriting ``obj_id`` (as before) sent the
            # request to the *local* export id instead of the remote target.
            export_id = self.svr.register(args[0])
            args = (export_id, ) + tuple(args[1:])
            dtype |= DT_PROXY
        if pickle:
            dtype |= DT_PICKLE
            argkw = pickle_dumps((args, kw), PICKLE_PROTOCOL)
        else:
            argkw = dumps((args, kw))
        index = next(self.iter_id)
        if no_result:
            self.send(dtype | ST_NO_RESULT, obj_id, index, name, argkw)
            return None
        # Register the waiter before sending, so a reply can never arrive
        # ahead of its slot.
        rs = AsyncResult()
        self._resps[index] = rs
        try:
            self.send(dtype, obj_id, index, name, argkw)
        except BaseException:
            self._resps.pop(index, None)
            raise
        return self._read_response(index, timeout, rs)

    def _handle_request(self, parts):
        """Execute one incoming request and send back its result or error.

        SECURITY NOTE(review): the peer may invoke any attribute of any
        registered export, and of this service itself through id 0. Its
        DT_PICKLE payloads are also unpickled. Only use with trusted peers.
        """
        dtype, obj_id, index, name, argkw = parts
        try:
            obj = self.get_export(obj_id)
            if obj is None:
                raise RpcExportNoFound(obj_id)
            # Without a default, getattr raised AttributeError and the
            # RpcFuncNoFound branch could never be reached.
            func = getattr(obj, name, None)
            if func is None:
                raise RpcFuncNoFound(name)
            if dtype & DT_PICKLE:
                args, kw = pickle_loads(argkw)
            else:
                args, kw = loads(argkw)
            if dtype & DT_PROXY:
                args = (self.get_proxy(args[0]), ) + tuple(args[1:])
            rs = func(*args, **kw) if kw is not None else func(*args)
            if dtype & ST_NO_RESULT:
                return
            # Encode first, then send once: a send failure is no longer
            # mistaken for an encoding failure.
            try:
                rtype, payload = RT_RESPONSE, dumps(rs)
            except Exception:
                # msgpack cannot encode this value; fall back to pickle.
                rtype = RT_RESPONSE | DT_PICKLE
                payload = pickle_dumps(rs, PICKLE_PROTOCOL)
            self.send(rtype, index, payload)
        except Exception as e:
            log_except()
            if dtype & ST_NO_RESULT or self.svr.stoped:
                return
            try:
                payload = pickle_dumps(e, PICKLE_PROTOCOL)
            except Exception:
                # Unpicklable exception: still tell the caller something went
                # wrong instead of leaving it waiting for its timeout.
                payload = pickle_dumps(RpcCallError(repr(e)), PICKLE_PROTOCOL)
            self.send(RT_EXCEPTION, index, payload)

    def _handle_response(self, parts):
        """Deliver a reply to the caller waiting on its request id."""
        dtype, index, argkw = parts
        rs = self._resps.pop(index, None)
        if rs is None:
            # Unknown or expired request (e.g. the caller already timed out).
            return
        try:
            result = pickle_loads(argkw) if dtype & DT_PICKLE else loads(argkw)
        except Exception as e:
            # Wake the caller with the decode error rather than letting it
            # wait out its whole timeout.
            rs.set_exception(e)
        else:
            rs.set(result)

    def _handle_exception(self, parts):
        """Hand a peer-side exception to the caller waiting on its request id."""
        _dtype, index, error = parts
        rs = self._resps.pop(index, None)
        if rs is None:
            return
        try:
            error = pickle_loads(error)
        except Exception as e:
            error = e
        rs.set_exception(error)

    def _handle(self, parts):
        """Decode one message and dispatch it on its message-type bits."""
        if len(parts) == 2:
            parts = (parts[0], ) + tuple(loads(parts[1]))
        else:
            parts = loads(parts[0])
        rt = parts[0] & RT_MARK
        if rt == RT_REQUEST:
            self._handle_request(parts)
        elif rt == RT_RESPONSE:
            self._handle_response(parts)
        elif rt == RT_EXCEPTION:
            self._handle_exception(parts)
        elif rt == RT_HEARTBEAT:
            self._heart_time = time.time()
        else:
            # Formatting used to reference an undefined name (``msg``), which
            # turned this into a NameError.
            raise ValueError('unknown data:%s' % (parts, ))

    def heartbeat(self):
        """Send a heartbeat every HEARTBEAT_TIME seconds until stopped.

        If no heartbeat has been received from the peer for more than
        HEARTBEAT_TIME + 10 seconds, the link is declared dead, and the
        service stops without notifying the peer.
        """
        btime = HEARTBEAT_TIME
        while not self.stoped:
            self.send(RT_HEARTBEAT)
            sleep(btime)
            if self._heart_time + btime + 10 < time.time():
                self.heart_timeout = True
                break
        self.stop()

    #######remote call##############
    def get_export(self, export_id):
        """Return the local export with ``export_id``; 0 means this service.

        Returns None when the id is unknown or its object has been collected.
        """
        if export_id == 0:
            return self
        wr = self.svr._exports.get(export_id)
        return wr() if wr is not None else None

    def get_id_by_name(self, name):
        """Return the local export id registered under ``name``, or None."""
        return self.svr._names.get(name)

    def get_proxy(self, export_id):
        """Return the cached RpcProxy for remote ``export_id``, creating it."""
        proxy = self._proxys.get(export_id)
        if proxy is None:
            proxy = RpcProxy(self, export_id)
            self._proxys[export_id] = proxy
        return proxy

    def get_proxy_by_name(self, name):
        """Ask the peer for its export named ``name``; return a proxy or None."""
        export_id = self.call(0, 'get_id_by_name', (name, ), None)
        if export_id:
            return self.get_proxy(export_id)
        return None
class RpcClient(RpcBase):
    """Client side: an XREQ (DEALER) socket whose traffic all goes through
    a single RpcService, ``self.svc``."""

    zmq_type = zmq.XREQ

    def __init__(self, context=None, size=None):
        """:param context: zmq context (default: the shared one).
        :param size: maximum number of concurrent handler greenlets.
        """
        # Forward the caller's arguments. Passing context=None, size=None
        # here silently ignored both.
        RpcBase.__init__(self, context=context, size=size)
        self.svc = RpcService(self, None)

    def _handle(self, parts):
        """Dispatch one incoming message to the single service."""
        self.svc._handle(parts)

    def before_stop(self):
        """Stop the service (which notifies the server) before the socket closes."""
        self.svc.stop()

    def svc_stop(self, service):
        """Service-stopped callback: when our service stops, the client stops."""
        if service is self.svc:
            self.stop()
class RpcServer(RpcBase):
    """Server side: an XREP (ROUTER) socket with one RpcService per peer identity."""

    zmq_type = zmq.XREP

    def __init__(self, context=None, size=None):
        """:param context: zmq context (default: the shared one).
        :param size: maximum number of concurrent handler greenlets.
        """
        # Forward the caller's arguments. Passing context=None, size=None
        # here silently ignored both.
        RpcBase.__init__(self, context=context, size=size)
        # peer identity -> RpcService
        self._services = {}

    def before_stop(self):
        """Stop every peer service before the socket closes."""
        # Iterate a snapshot: each svc.stop() removes itself via svc_stop().
        for svc in list(self._services.values()):
            svc.stop()

    def get_service(self, identity):
        """Return the service for ``identity``, creating it on first contact."""
        svc = self._services.get(identity)
        if svc is None:
            svc = self._services[identity] = RpcService(self, identity)
        return svc

    def _handle(self, parts):
        """Route a message (identity frame first) to that peer's service."""
        self.get_service(parts[0])._handle(parts[1:])

    def svc_stop(self, service):
        """Service-stopped callback: forget the stopped service."""
        self._services.pop(service.identity, None)
class RpcProxy(object):
    """Local stand-in for an object exported by the peer.

    Accessing a public attribute returns a function that performs the remote
    call through the owning RpcService. These keyword arguments tune the
    call and are not forwarded:
    - ``_no_result`` (default False);
    - ``_timeout`` (default 60);
    - ``_pickle`` (default False);
    - ``_proxy`` (default False).
    """

    def __init__(self, svc, export_id):
        self._svc = svc
        self._id = export_id
        # callbacks run by on_close(), in subscription order
        self._closes = []

    def __getattr__(self, attribute):
        # Dunder names are probed by the interpreter and by libraries
        # (hasattr checks, copy, pickle's __getstate__ lookup). Answering
        # them with a remote-call stub made such probes succeed and, when
        # called, go to the peer.
        if attribute.startswith('__') and attribute.endswith('__'):
            raise AttributeError(attribute)

        def _func(*args, **kw):
            no_result = kw.pop('_no_result', False)
            timeout = kw.pop('_timeout', 60)
            pickle = kw.pop('_pickle', False)
            proxy = kw.pop('_proxy', False)
            return self._svc.call(self._id, attribute, args, kw,
                                  no_result=no_result, timeout=timeout,
                                  pickle=pickle, proxy=proxy)
        return _func

    def on_close(self):
        """Call every close callback with this proxy, logging their errors."""
        # Iterate a snapshot: a callback may unsub_close() itself, which
        # previously made the next callback get skipped.
        for func in list(self._closes):
            try:
                func(self)
            except StandardError:
                log_except()

    def sub_close(self, func):
        """Register ``func(proxy)`` to run on close (duplicates are ignored)."""
        if func not in self._closes:
            self._closes.append(func)

    def unsub_close(self, func):
        """Remove a close callback if it is registered."""
        if func in self._closes:
            self._closes.remove(func)
def test():
    """Manual end-to-end demo of the RPC layer inside one process.

    Starts an RpcServer on tcp://127.0.0.1:8081 that exports a ``Test``
    object. Then runs ``client_test`` in a greenlet, which looks the export
    up by name and prints the results of:
    - a plain call;
    - an args/kwargs echo;
    - a remote exception;
    - a call that passes a local object as a proxy;
    - a fire-and-forget call.
    Finally it stops the client, then the server.

    Side effect: sets the module-level HEARTBEAT_TIME to 1 and never
    restores it.
    """
    class Test(AbsExport):
        _rpc_name_ = 'test'
        def t1(self):
            return 't1'
        def t2(self, *args, **kw):
            # Returns the received arguments unchanged.
            return (args, kw)
        def t3(self):
            # Raised on the serving side; the client catches it as ValueError.
            raise ValueError, 't3'
        def t4(self, proxy, a, b):
            # ``proxy`` wraps the object the client passed with _proxy=True;
            # calling it issues a call back over the same link.
            rs = proxy.t1()
            return rs, a, b
        def tt(self):
            # Invoked with _no_result=True, so this error is only logged where
            # it runs. The return below is unreachable.
            raise ValueError, 't4'
            return 'no_result'
    global HEARTBEAT_TIME
    endpoint = 'tcp://127.0.0.1:8081'
    # One-second heartbeats for this demo instead of the default 30 seconds.
    HEARTBEAT_TIME = 1
    t1 = Test()
    svr = RpcServer()
    svr.bind(endpoint)
    svr.register(t1)
    svr.start()
    def client_test(name):
        # One client session; ``name`` prefixes every printed line.
        client = RpcClient()
        client.connect(endpoint)
        client.start()
        proxy = client.svc.get_proxy_by_name(t1._rpc_name_)
        print name, proxy.t1()
        print name, proxy.t2(1,2,3, a=1, b=2)
        try:
            proxy.t3()
        except ValueError as e:
            print name, e
        print name, proxy.t4(t1, 'a', 'b', _proxy=True)
        # Fire-and-forget: the call returns None without waiting.
        print name, proxy.tt(_no_result=True)
        # Idle for 3 seconds (several 1-second heartbeat intervals).
        gevent.sleep(3)
        ## #print name, svr._exports
        ## print name, len(client.svc._proxys)
        ## del proxy
        ## #del t1
        ## import gc
        ## gc.collect()
        ## #print name, svr._exports
        ## print name, len(client.svc._proxys)
        client.stop()
    clients = []
    clients.append(spawn(client_test, 'A'))
    #clients.append(spawn(client_test, 'B'))
    gevent.joinall(clients)
    svr.stop()
def benchmark(endpoint='tcp://127.0.0.1:8081', number=1000):
    """Time sequential echo round-trips between an in-process server and client.

    :param endpoint: address the server binds and the client connects to.
    :param number: number of echo calls to time.

    Prints the total time and the calls per second. The client and server
    are stopped even if a call or a check fails, so the socket and port are
    released.
    """
    import timeit

    class Test(AbsExport):
        _rpc_name_ = 'test'

        def echo(self, msg):
            return msg

    t1 = Test()
    svr = RpcServer()
    svr.bind(endpoint)
    svr.register(t1)
    svr.start()
    client = RpcClient()
    try:
        client.connect(endpoint)
        client.start()
        proxy = client.svc.get_proxy_by_name(t1._rpc_name_)
        msg = '*' * 10
        # One-element list used as a counter the nested function can mutate
        # (Python 2 has no ``nonlocal``).
        index = [0]

        def test_echo():
            rs = proxy.echo(msg)
            assert rs == msg, 'echo error'
            index[0] += 1

        t = timeit.timeit(test_echo, number=number)
        print('total:%s %s per/sec' % (t, number / float(t)))
        assert index[0] == number, 'index error'
    finally:
        client.stop()
        svr.stop()
def main():
    """Script entry point: run the echo benchmark.

    ``test()`` is the alternative end-to-end demo.
    """
    benchmark()


if __name__ == '__main__':
    main()
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment