Created
May 29, 2012 15:28
-
-
Save jmkacz/2829062 to your computer and use it in GitHub Desktop.
Dvir's latest version of TProcessPoolServer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# | |
import errno
import logging
import os
import signal
import socket
import struct
import sys
import time
from multiprocessing import Process, Value, Condition, reduction, Lock

from thrift.transport.TTransport import TTransportException

from TServer import TServer
#import prctl
class TProcessPoolServer(TServer):
    """
    Server with a fixed size pool of worker subprocesses which service requests.
    Note that if you need shared state between the handlers - it's up to you!
    Written by Dvir Volk, doat.com

    The parent binds and listens once, then forks the workers; every worker
    calls accept() on that same listening transport, so there is no queue -
    whichever worker's accept() returns gets the client. The parent itself
    only waits in serve() until stop() is called.
    """
    DEFAULT_NUM_WORKERS = 8
    # NOTE(review): not referenced by this class; kept for external readers.
    CLIENT_TIMEOUT = 20

    def __init__(self, *args):
        """Forward *args to TServer.__init__ and create the pool's shared state.

        Also tries to make this process the leader of its own process group,
        so that stop() can signal every process in that group.
        """
        TServer.__init__(self, *args)
        self.numWorkers = TProcessPoolServer.DEFAULT_NUM_WORKERS
        # Only processes whose start() succeeded are recorded here.
        self.workers = []
        # Shared with the forked workers: they keep serving while it is true.
        self.isRunning = Value('b', False)
        # serve() blocks on this condition until stop() notifies it.
        self.stopCondition = Condition()
        self.postForkCallback = None
        self.shutdownCallback = None
        self.parentPid = os.getpid()
        try:
            os.setpgid(self.parentPid, self.parentPid)
        except OSError as e:
            # setpgid(pid, pid) fails with EPERM for a session leader (e.g. a
            # process started via setsid); such a process already leads its
            # own process group, so this is not fatal.
            logging.warning("could not set process group: %s", e)
        # Created before forking so all workers share it; serializes the
        # post-fork callback across workers.
        self._lock = Lock()

    def setPostForkCallback(self, callback):
        """
        Set a callback to be called in the workers AFTER they have forked.
        This is useful for them to start threads, open sockets to databases, etc

        Workers run it one at a time (under a lock shared by all workers)
        before they start accepting clients.

        Raises TypeError if callback is not callable.
        """
        if not callable(callback):
            raise TypeError("This is not a callback!")
        self.postForkCallback = callback

    def setShutdownCallback(self, callback):
        """
        Set a callback to be called when we need to shut down the server.

        It is invoked inside each worker process after that worker's serve
        loop ends; exceptions it raises are logged, not propagated.

        Raises TypeError if callback is not callable.
        """
        if not callable(callback):
            raise TypeError("This is not a callback!")
        self.shutdownCallback = callback

    def setNumWorkers(self, num):
        """Set the number of worker sub procs that serve() should create."""
        self.numWorkers = num

    def workerProcess(self, workerNum):
        """Worker main loop: accept clients on the shared socket and serve them.

        Runs in a child process until the shared isRunning flag is cleared or
        the worker is interrupted. Runs the post-fork callback first and the
        shutdown callback on the way out.

        Returns 0 if interrupted during the post-fork callback, else None.
        """
        self.workerNum = workerNum
        logging.info("Worker starting! %s %s", workerNum, os.getpid())
        if self.postForkCallback:
            try:
                with self._lock:
                    self.postForkCallback()
            # catch system exit while in post forking
            except (KeyboardInterrupt, SystemExit):
                logging.info("Worker closing! %s %s", workerNum, os.getpid())
                return 0
            except Exception as x:
                logging.exception(x)

        while self.isRunning.value:
            try:
                try:
                    client = self.serverTransport.accept()
                except Exception as e:
                    # Report the real failure instead of assuming a timeout.
                    logging.warning('accept failed: %s', e)
                    continue
                self.serveClient(client)
            except (KeyboardInterrupt, SystemExit):
                logging.info("Worker closing! %s %s", workerNum, os.getpid())
                break
            except Exception as x:
                logging.exception(x)

        logging.info("Shutting Down")
        # Call the shutdown callback if necessary
        if self.shutdownCallback:
            try:
                self.shutdownCallback()
            except Exception as e:
                logging.exception(e)
        logging.info("Process %s exiting!", os.getpid())

    def serveClient(self, client):
        """Process requests from one client until its connection ends.

        Wraps client in the configured transport/protocol factories and calls
        processor.process() in a loop. A TTransportException (raised by the
        transport layer, e.g. when the client goes away) ends the session
        quietly; other errors are logged. Both transports are then closed.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)
        oprot = self.outputProtocolFactory.getProtocol(otrans)
        try:
            while True:
                self.processor.process(iprot, oprot)
        except TTransportException:
            pass
        except (SystemExit, KeyboardInterrupt):
            pass
        except Exception as x:
            logging.exception(x)

        # Close each transport on its own so a failure closing the first
        # cannot leave the second open.
        for trans in (itrans, otrans):
            try:
                trans.close()
            except Exception as e:
                logging.exception(e)

    def serve(self):
        """Listen, fork numWorkers worker processes, and block until stop().

        Returns once stop() has cleared isRunning (or an interrupt is caught
        while waiting); isRunning is always False on return.
        """
        # Shared flag that tells the workers to exit when set to False.
        self.isRunning.value = True

        # Bind and listen before forking so every worker inherits the socket.
        self.serverTransport.listen()

        # SO_LINGER on (l_onoff=1) with a zero timeout (l_linger=0) on the
        # listening socket; useful if you're constantly opening/closing
        # connections.
        try:
            self.serverTransport.handle.setsockopt(
                socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
        except Exception as e:
            logging.error("could not set linger: %s", e)

        self._forkWorkers()

        # Wait until stop() clears the flag. The flag is re-checked while the
        # condition's lock is held, so a stop() that happens before wait() is
        # not missed, and the `with` releases the lock even if wait() raises.
        while self.isRunning.value:
            try:
                with self.stopCondition:
                    while self.isRunning.value:
                        self.stopCondition.wait()
            except (SystemExit, KeyboardInterrupt):
                logging.warning("Got interrupt!")
                break
            except Exception as x:
                logging.exception(x)
        self.isRunning.value = False

    def _forkWorkers(self):
        """Start up to numWorkers daemon worker processes; stop early on interrupt."""
        for i in range(self.numWorkers):
            if not self.isRunning.value:
                break
            try:
                worker = Process(target=self.workerProcess, args=(i,),
                                 name='ServerWorker-%s' % i)
                worker.daemon = True
                worker.start()
                # Record only started workers: stop() joins/terminates these,
                # which is invalid on a Process that was never started.
                self.workers.append(worker)
            # catch system exit while forking children
            except (SystemExit, KeyboardInterrupt):
                logging.warning("Got interrupt!")
                self.isRunning.value = False
                break
            except Exception as x:
                logging.exception(x)
        logging.info("Exited forking loop!")

    def stop(self):
        """Stop the pool: wake serve(), reap the workers, close the socket.

        Clears isRunning and notifies serve(), joins each worker for up to one
        second (terminating any still alive), closes the server transport,
        sends SIGTERM to the rest of the process group, and waits for a child.
        """
        with self.stopCondition:
            self.isRunning.value = False
            self.stopCondition.notify_all()
        logging.info("Stopped process pool server")
        for proc in self.workers:
            logging.info("Joining worker %s. alive? %s", proc, proc.is_alive())
            try:
                proc.join(1.0)
                # terminate the process anyways
                if proc.is_alive():
                    proc.terminate()
                logging.info("Worker %s joined!", proc)
            except Exception as e:
                logging.exception(e)
                proc.terminate()
        self.serverTransport.close()
        # send all the workers (and anything they spawned) SIGTERM just in case
        self._signalProcessGroup(signal.SIGTERM)
        try:
            os.waitpid(0, 0)
        except OSError as e:
            # ECHILD: no child left to wait for, which is the expected case
            # once every worker has been joined.
            if e.errno != errno.ECHILD:
                raise

    def _signalProcessGroup(self, signum):
        """Send signum to this server's process group without hitting ourselves.

        The parent is a member of the group it signals, so signum is ignored
        in this process while signalling; otherwise SIGTERM's default action
        would kill the parent mid-stop, and a handler calling stop() would
        re-enter it. signal.signal() only works in the main thread; from other
        threads the group is signalled without that protection.
        """
        ignored = False
        previous = None
        try:
            previous = signal.signal(signum, signal.SIG_IGN)
            ignored = True
        except ValueError:
            pass
        try:
            os.killpg(os.getpgid(self.parentPid), signum)
        finally:
            if ignored:
                # signal.signal() returns None for handlers not set from
                # Python; fall back to the default disposition in that case.
                signal.signal(signum, signal.SIG_DFL if previous is None else previous)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment