Skip to content

Instantly share code, notes, and snippets.

@alukach
Created May 11, 2017 20:52
Show Gist options
  • Save alukach/d67443af769ddecf5153f4e7eca2fd0a to your computer and use it in GitHub Desktop.
Save alukach/d67443af769ddecf5153f4e7eca2fd0a to your computer and use it in GitHub Desktop.
Threading Context Manager
from multiprocessing import cpu_count
import threading
import queue
import logging
logger = logging.getLogger(__name__)


class ThreadQueue(object):
    """Context manager that runs queued callables on a pool of worker threads.

    Entering the context starts ``num_threads`` worker threads and returns
    the underlying :class:`queue.Queue`. Put tasks on it as either a bare
    callable or a sequence ``(func,)``, ``(func, args)`` or
    ``(func, args, kwargs)``.

    On a normal exit the context blocks until every queued task has been
    processed, then stops and joins the workers. If the ``with`` body raised,
    tasks still waiting in the queue are not run: each worker finishes the
    task it is currently executing and stops.

    Exceptions raised by a task, or caused by a malformed item, are logged
    and do not stop the worker.

    Args:
        thread_multiplier: workers per CPU (``multiprocessing.cpu_count()``);
            the pool always has at least one worker.
        poll_interval: seconds a worker waits on an empty queue before
            re-checking whether it should stop.
    """

    def __init__(self, thread_multiplier=2, poll_interval=1):
        self.q = queue.Queue()
        # Always start at least one worker: with zero workers a normal exit
        # would block forever in ``q.join()``. ``int()`` lets fractional
        # multipliers work with ``range()``.
        self.num_threads = max(1, int(cpu_count() * thread_multiplier))
        self.poll_interval = poll_interval
        self.killswitch = threading.Event()
        self._threads = []

    def __enter__(self):
        """Start the worker threads and return the task queue."""
        logger.debug("Entering ThreadQueue context")
        # Reset the stop flag so one instance can serve several consecutive
        # ``with`` blocks.
        self.killswitch.clear()
        self._threads = [
            threading.Thread(target=self.worker, args=("Thread-{}".format(i),))
            for i in range(self.num_threads)
        ]
        for thread in self._threads:
            thread.start()
        return self.q

    def __exit__(self, exc_type, exc_value, traceback):
        """Drain the queue (on success), then stop and join the workers.

        Returns None, so an exception raised in the ``with`` body propagates.
        """
        logger.debug("Exiting ThreadQueue context")
        try:
            if exc_type is None:
                # Wait for every queued task. Workers call task_done() even
                # when a task fails, so a failing task does not stall this.
                self.q.join()
        finally:
            self.killswitch.set()
            for thread in self._threads:
                thread.join()

    def worker(self, name):
        """Run queued tasks until the killswitch is set.

        Args:
            name: label used in this worker's log messages.
        """
        logger.debug("Starting thread %s", name)
        while not self.killswitch.is_set():
            try:
                item = self.q.get(timeout=self.poll_interval)
            except queue.Empty:
                # Wake up periodically to re-check the killswitch.
                continue
            logger.debug("%s got %r", name, item)
            try:
                func, args, kwargs = self._parse_item(item)
                func(*args, **kwargs)
            except Exception:
                # A bad item or failing task must neither kill this worker
                # nor leave the item unacknowledged (q.join() would hang).
                logger.exception("%s failed to run task %r", name, item)
            finally:
                self.q.task_done()
        logger.debug("Stopping thread %s", name)

    @staticmethod
    def _parse_item(item):
        """Split a queued item into ``(func, args, kwargs)``.

        Accepted forms: a bare callable (any object without ``__len__``), or
        a sequence ``(func,)``, ``(func, args)`` or ``(func, args, kwargs)``
        where ``args`` has a length and ``kwargs`` is a dict.

        Raises:
            TypeError: if the item does not match one of those forms.
        """
        err = ("Improperly formatted arguments. Expected tuple(func, "
               "args, kwargs). Got {}")
        if not item:
            raise TypeError(err.format(item))
        args = ()
        kwargs = {}
        if not hasattr(item, '__len__'):
            func = item
        else:
            if len(item) > 3:
                raise TypeError(err.format(item))
            func = item[0]
            if len(item) > 1:
                args = item[1]
            if len(item) > 2:
                kwargs = item[2]
        if not callable(func):
            raise TypeError(err.format(item))
        if not hasattr(args, '__len__') or not isinstance(kwargs, dict):
            raise TypeError(err.format(item))
        return func, args, kwargs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment