Skip to content

Instantly share code, notes, and snippets.

@tai
Created July 6, 2015 14:41
Show Gist options
  • Select an option

  • Save tai/ea2e9936e31932fe642c to your computer and use it in GitHub Desktop.

Select an option

Save tai/ea2e9936e31932fe642c to your computer and use it in GitHub Desktop.
Simple workqueue implementation to serialize requests to a pexpect-wrapped process.
#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Workqueue test to serialize requests sent from multiple threads
# to protect pexpect-based backend.
#
import sys, os
import unittest
import glob
import random
import functools
import threading
import pexpect
from pexpect import EOF, TIMEOUT
from IPython import embed
import logging
log = logging.getLogger(__name__)
#logging.basicConfig(level=logging.DEBUG)
PS1 = "PROMPT> "
class FooService(object):
    """Sample pexpect-based service to be tested.

    Wraps a /bin/sh subprocess behind a fixed prompt (PS1) so that each
    command's output can be delimited with expect(). The shell is a
    single shared resource: callers must serialize access to it (which
    is what WorkQueue is for).
    """
    def __init__(self):
        # Disable echo and install a known prompt so expect(PS1) marks
        # the end of each command's output.
        sh = pexpect.spawn("/bin/sh")
        sh.setecho(False)
        sh.sendline("PS1='" + PS1 + "'")
        sh.expect(PS1)
        self.sh = sh

    def mkdir(self, path):
        """Create directory *path*; True when the prompt came back."""
        sh = self.sh
        sh.sendline("mkdir " + path)
        return sh.expect(PS1) == 0

    def chdir(self, path):
        """Change the shell's working directory to *path*.

        NOTE(review): this sends the shell command "chdir" — confirm the
        target /bin/sh supports it (many shells only provide "cd").
        """
        sh = self.sh
        sh.sendline("chdir " + path)
        return sh.expect(PS1) == 0

    def pwd(self, exp=None):
        """Return the shell's current working directory as a str.

        When *exp* is given and differs from the result, dump diagnostic
        state to stdout (used by the stress test). Raises Exception on
        EOF or a 3-second timeout from the subprocess.
        """
        sh = self.sh
        sh.sendline("pwd")
        i = sh.expect([PS1, EOF, TIMEOUT], timeout=3)
        if i == 0:
            ret = sh.before.decode('utf-8').strip()
            if exp and ret != exp:
                # BUGFIX: the original formatted "%d" with an undefined
                # name `th`, raising NameError instead of printing the
                # diagnostics below.
                print("=== error: mismatch ===")
                print("exp: " + exp)
                print("ret: " + ret)
                print("buffer: " + str(sh.buffer))
                print("before: " + str(sh.before))
                print(" after: " + str(sh.after))
            return ret
        raise Exception("pwd")

    def quit(self):
        """Terminate the underlying shell process."""
        # BUGFIX: the original referenced a bare, undefined `sh`
        # (NameError); the spawned shell lives on self.sh.
        self.sh.close()
        return True
class WorkQueue(threading.Thread):
    """A simple work queue for in-memory task serialization.

    Callables posted from any thread are executed one at a time, in
    FIFO order, on this single worker thread — protecting a
    non-reentrant backend (e.g. a pexpect-driven process) from
    interleaved requests.

    Usable as a context manager: ``with WorkQueue() as wq: ...``.
    """
    def __init__(self):
        # BUGFIX: super(self.__class__, ...) recurses infinitely if this
        # class is ever subclassed; name the class explicitly.
        super(WorkQueue, self).__init__()
        self.que = []
        # Counts pending requests; the worker blocks on acquire() until
        # post() releases it (or close() releases it to shut down).
        self.sem = threading.Semaphore(0)
        self.quit = False
        self.start()

    def __enter__(self, *args, **kwargs):
        return self

    def __exit__(self, *args, **kwargs):
        self.close()

    def reset(self):
        """Force reset: drain the semaphore and drop queued work.
        Mostly for unittest."""
        while self.sem.acquire(blocking=False):
            pass
        self.que.clear()

    def close(self):
        """Shutdown workqueue (wakes the worker so it can exit)."""
        self.quit = True
        self.sem.release()

    def post(self, opt, func, *args, **kwds):
        """Enqueue func(*args, **kwds); returns a TX object for tracking.

        The TX carries .ev (an Event set on completion) and, afterwards,
        either .result (success) or .exc_info (failure).
        *opt* is unused here; kept for interface symmetry with call().
        """
        fp = functools.partial(func, *args, **kwds)
        fp.tx = lambda: 0           # cheap anonymous attribute holder
        fp.tx.ev = threading.Event()
        self.que.append(fp)
        self.sem.release()
        return fp.tx

    def call(self, opt, func, *args, **kwds):
        """Blocking call: run *func* on the worker and return its result.

        Waits up to opt['timeout'] seconds (forever when absent).
        Raises Exception("timeout") on deadline, and re-raises the
        worker-side exception when *func* failed.
        """
        tx = self.post(opt, func, *args, **kwds)
        if not tx.ev.wait(opt.get('timeout')):
            raise Exception("timeout")
        exc = getattr(tx, 'exc_info', None)
        if exc is not None:
            # BUGFIX: the original read tx.result unconditionally, so a
            # worker-side failure surfaced as AttributeError; propagate
            # the real exception (with its traceback) instead.
            raise exc[1].with_traceback(exc[2])
        return tx.result

    def run(self):
        """Worker loop: execute queued requests in order until close()."""
        log = logging.getLogger(__name__)
        log.debug("wq: running")
        while not self.quit:
            self.sem.acquire()
            if self.quit:
                break
            log.debug("request: fetching")
            fp = self.que.pop(0)
            try:
                log.debug("request: running")
                fp.tx.result = fp()
                log.debug("request: returned")
            except Exception:
                log.debug("request: aborted")
                fp.tx.exc_info = sys.exc_info()
            # Always wake the caller, success or failure.
            fp.tx.ev.set()
        log.debug("wq: exiting")
class FooTest(unittest.TestCase):
    def test_fgapp(self):
        """Tests service APIs in foreground"""
        app = FooService()
        for base, _dirs, _files in os.walk("/proc/fs"):
            app.chdir(base)
            self.assertEqual(app.pwd(), base)

    def test_bgapp(self):
        """Tests service APIs in background, through workqueue"""
        with WorkQueue() as wq:
            app = FooService()

            # sample to batch up multiple calls
            def chdir_pwd(dst):
                app.chdir(dst)
                return app.pwd()

            for base, _dirs, _files in os.walk("/proc/fs"):
                self.assertEqual(wq.call({}, chdir_pwd, base), base)

    def test_bgapp_hard(self):
        """Tests workqueue harder"""
        with WorkQueue() as wq:
            app = FooService()

            # sample to batch up multiple calls
            def chdir_pwd(dst):
                app.chdir(dst)
                return app.pwd(dst)

            # tests above batch for various dirs
            def run_test(err_counts, idx):
                targets = [path for path, *_ in os.walk("/proc/fs")]
                random.shuffle(targets)
                for dst in targets:
                    try:
                        if wq.call({}, chdir_pwd, dst) != dst:
                            raise Exception("no match")
                    except:
                        err_counts[idx] += 1

            # run chdir-pwd tests in parallel
            workers = []
            results = []
            for idx in range(10):
                results.append(0)
                workers.append(
                    threading.Thread(None, target=run_test, args=(results, idx)))
            for w in workers:
                w.start()
            for w in workers:
                w.join()
            # NOTE:
            # - unittest API can be called in thread context,
            #   but has no effect on final result. So error in
            #   each thread needs to be counted manually.
            self.assertEqual(sum(results), 0,
                             "MT call should not cause an error.")
# Run the test suite when executed as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment