Created
May 22, 2023 16:26
-
-
Save adivekar-utexas/d9baa02c4f2994d7c0127b9c4228eda9 to your computer and use it in GitHub Desktop.
Python Concurrency utils
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A collection of utilities to augment the Python language:""" | |
from typing import * | |
import time, traceback, random, sys | |
import math, gc | |
from datetime import datetime | |
from math import inf | |
import numpy as np | |
from threading import Semaphore | |
import multiprocessing as mp | |
from concurrent.futures._base import Future | |
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, wait as wait_future | |
from concurrent.futures.thread import BrokenThreadPool | |
from concurrent.futures.process import BrokenProcessPool | |
import ray | |
from ray.util.dask import RayDaskCallback | |
from pydantic import validate_arguments, conint, confloat | |
def concurrent(max_active_threads: int = 10, max_calls_per_second: float = inf):
    """
    Decorator which runs function calls concurrently via multithreading.

    When you decorate an IO-bound function with @concurrent(MAX_THREADS) and then invoke it N times in a loop, at most
    min(MAX_THREADS, N) invocations run concurrently. For example, if your function calls another service and you must
    invoke it N times, decorating it with @concurrent(3) ensures that only 3 calls (and thus 3 requests to the
    downstream service) are in flight at any time. This limits the number of connections made to that service.

    As this uses multithreading rather than multiprocessing, it is suited to IO-heavy functions, not CPU-heavy ones.

    Each call to the decorated function returns a Future; calling .result() on that future returns the value.
    Typically you call the decorated function N times in a loop, store the futures in a list/dict, and then call
    .result() on each future, saving the results into a new list/dict. Each .result() call blocks, so the order of
    items is preserved between the two collections.

    If a function call raises an exception, calling .result() on its future re-raises that exception in the
    orchestrating code. If several calls raise, the exception surfaces from whichever future .result() is called on
    first. Add try/except logic inside your decorated function if exceptions should be handled there.

    A decorated function `a` can call another decorated function `b`. It is up to `a` whether to call .result() on
    the futures it gets from `b`, or to return those futures to its own caller.

    `max_calls_per_second` throttles how often the function is started. This matters for functions which finish
    quickly. For example, with a max concurrency of 5 and calls that take 100ms each, we would make
    1000 / 100 * 5 = 50 calls per second to the downstream service; pass `max_calls_per_second` to cap that rate.
    The throttle is shared by every thread that calls the decorated function.

    The decorated function keeps the wrapped function's name and docstring (via functools.wraps).

    :param max_active_threads: the max number of threads which can be running the function at one time, i.e. the
        max concurrency factor. A call made while all slots are busy blocks the caller until a slot frees up.
    :param max_calls_per_second: the max rate at which new calls are started. Defaults to no limit.
    :return: N/A, this is a decorator.
    """
    ## Refs:
    ## 1. ThreadPoolExecutor: docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor.submit
    ## 2. Decorators: www.datacamp.com/community/tutorials/decorators-python
    ## 3. Semaphores: www.geeksforgeeks.org/synchronization-by-using-semaphore-in-python/
    ## 4. Overall code: https://gist.github.com/gregburek/1441055#gistcomment-1294264
    import functools
    from threading import Lock

    def decorator(function):
        ## Each decorated function gets its own executor, semaphore and rate limiter. If you write two decorated
        ## functions `say_hi` and `say_bye`, all calls to `say_hi` share one set and all calls to `say_bye` share
        ## another. E.g. with max_active_threads=5, 30 calls to `say_hi` run 5 at a time (enforced by the semaphore).
        executor = ThreadPoolExecutor(max_workers=max_active_threads)
        semaphore = Semaphore(max_active_threads)
        ## The minimum time between the starts of two consecutive calls (0.0 when unthrottled).
        min_time_interval_between_calls = 1 / max_calls_per_second
        ## Serializes the read-sleep-update sequence below, so that callers on different threads (e.g. worker threads
        ## of another @concurrent function) cannot read the same timestamp and start at the same moment.
        rate_limit_lock = Lock()
        ## Monotonic timestamp of the previous call start; -inf means "never called", so the first call never waits.
        time_last_called = -inf

        def run_function(*args, **kwargs):
            try:
                return function(*args, **kwargs)
            finally:
                semaphore.release()  ## Free the slot even when the function raises.

        @functools.wraps(function)
        def wrapper(*args, **kwargs) -> Future:
            nonlocal time_last_called
            semaphore.acquire()
            try:
                with rate_limit_lock:
                    time_elapsed_since_last_called = time.monotonic() - time_last_called
                    time.sleep(max(0.0, min_time_interval_between_calls - time_elapsed_since_last_called))
                    time_last_called = time.monotonic()
                return executor.submit(run_function, *args, **kwargs)  ## return a future
            except BaseException:
                ## run_function was never scheduled (e.g. the executor was shut down, or the caller was interrupted
                ## while sleeping), so nothing else will release the slot acquired above.
                semaphore.release()
                raise

        return wrapper

    return decorator
## Shared thread pool used by `run_concurrent` when no executor is passed. It is replaced if it ever breaks.
_GLOBAL_THREAD_POOL_EXECUTOR: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=127)


def run_concurrent(
        fn,
        *args,
        executor: Optional[ThreadPoolExecutor] = None,
        **kwargs,
) -> Future:
    """
    Submits `fn(*args, **kwargs)` to a thread pool and returns the resulting future.

    :param fn: callable to run in a worker thread.
    :param args: positional arguments forwarded to `fn`.
    :param executor: pool to submit to. Defaults to the module-level shared pool (127 workers).
    :param kwargs: keyword arguments forwarded to `fn`. A keyword named `executor` is consumed by this function and
        is never forwarded.
    :return: a Future for the call.
    :raise BrokenThreadPool: if a caller-supplied `executor` is broken. If the shared pool is broken, it is instead
        replaced by a fresh pool with the same number of workers, and the call is submitted there.
    """
    global _GLOBAL_THREAD_POOL_EXECUTOR
    if executor is None:
        executor = _GLOBAL_THREAD_POOL_EXECUTOR
    try:
        return executor.submit(fn, *args, **kwargs)  ## return a future
    except BrokenThreadPool:
        if executor is not _GLOBAL_THREAD_POOL_EXECUTOR:
            raise  ## Caller-owned pool: not ours to replace.
        ## A broken pool never recovers: swap in a fresh one, then release the broken pool's resources.
        _GLOBAL_THREAD_POOL_EXECUTOR = ThreadPoolExecutor(max_workers=executor._max_workers)
        executor.shutdown(wait=False)
        return _GLOBAL_THREAD_POOL_EXECUTOR.submit(fn, *args, **kwargs)  ## return a future
## Shared process pool used by `run_parallel` when no executor is passed. It has one worker per CPU with one CPU
## left free, clamped to the range [1, 32], and is replaced if it ever breaks.
_GLOBAL_PROCESS_POOL_EXECUTOR: ProcessPoolExecutor = ProcessPoolExecutor(
    max_workers=max(1, min(32, mp.cpu_count() - 1))
)


def run_parallel(
        fn,
        *args,
        executor: Optional[ProcessPoolExecutor] = None,
        **kwargs,
) -> Future:
    """
    Submits `fn(*args, **kwargs)` to a process pool and returns the resulting future.

    With a ProcessPoolExecutor, `fn`, its arguments and its return value cross a process boundary. They must
    therefore be picklable.

    :param fn: callable to run in a worker process.
    :param args: positional arguments forwarded to `fn`.
    :param executor: pool to submit to. Defaults to the module-level shared process pool.
    :param kwargs: keyword arguments forwarded to `fn`. A keyword named `executor` is consumed by this function and
        is never forwarded.
    :return: a Future for the call.
    :raise BrokenProcessPool: if a caller-supplied `executor` is broken. If the shared pool is broken (e.g. a worker
        died), it is instead replaced by a fresh pool with the same number of workers, and the call is submitted there.
    """
    global _GLOBAL_PROCESS_POOL_EXECUTOR
    if executor is None:
        executor = _GLOBAL_PROCESS_POOL_EXECUTOR
    try:
        return executor.submit(fn, *args, **kwargs)  ## return a future
    except BrokenProcessPool:
        if executor is not _GLOBAL_PROCESS_POOL_EXECUTOR:
            raise  ## Caller-owned pool: not ours to replace.
        ## A broken pool never recovers: swap in a fresh one, then release the broken pool's resources.
        _GLOBAL_PROCESS_POOL_EXECUTOR = ProcessPoolExecutor(max_workers=executor._max_workers)
        executor.shutdown(wait=False)
        return _GLOBAL_PROCESS_POOL_EXECUTOR.submit(fn, *args, **kwargs)  ## return a future
@ray.remote(num_cpus=1)
def __run_parallel_ray_single_cpu(fn, *args, **kwargs):
    """Ray task that calls `fn(*args, **kwargs)` inside a Ray worker, requesting one CPU from the Ray scheduler."""
    output = fn(*args, **kwargs)
    return output


def run_parallel_ray(fn, *args, **kwargs):
    """
    Schedules `fn(*args, **kwargs)` as a single-CPU Ray task and returns the handle produced by `.remote()`.

    The handle is expected to be a `ray.ObjectRef`; `get_result`, `is_done` and `wait` in this module accept such refs.
    `fn` and its arguments are shipped to a Ray worker, so they must be serializable by Ray.
    """
    remote_task = __run_parallel_ray_single_cpu
    return remote_task.remote(fn, *args, **kwargs)
def get_result(x):
    """
    Returns the value behind `x`, blocking until it is available.

    - concurrent.futures Future: returns `.result()`, which re-raises the call's exception if it failed.
    - Ray ObjectRef: fetched via `ray.get`.
    - Anything else: returned unchanged.
    """
    if isinstance(x, Future):
        value = x.result()
    elif isinstance(x, ray.ObjectRef):
        value = ray.get(x)
    else:
        value = x
    return value
def is_done(x) -> bool:
    """
    Non-blocking completion check.

    Futures report `.done()`. Ray ObjectRefs are polled with a zero timeout. Any other value counts as already done.
    """
    if isinstance(x, Future):
        return x.done()
    if not isinstance(x, ray.ObjectRef):
        return True
    ## Ref: docs.ray.io/en/latest/ray-core/tasks.html#waiting-for-partial-results
    ready, pending = ray.wait([x], timeout=0)  ## timeout=0 returns immediately.
    return bool(ready) and not pending
def accumulate(futures: Union[Tuple, List, Set, Dict, Any]) -> Union[Tuple, List, Set, Dict, Any]:
    """
    Join operation on a single future or a collection of futures.

    Lists, tuples and sets are resolved element-wise (via `get_result`) into a new container of the same builtin type;
    subclasses such as namedtuples come back as the plain builtin. For dicts, the values are resolved and the keys are
    kept. Anything else is passed straight to `get_result`. Nested containers are not descended into.
    """
    if isinstance(futures, dict):
        return {key: get_result(value) for key, value in futures.items()}
    for container_type in (list, tuple, set):
        if isinstance(futures, container_type):
            return container_type(get_result(item) for item in futures)
    return get_result(futures)
def wait_if_future(x):
    """
    Blocks until `x` completes if it is a Future or a Ray ObjectRef; any other value returns immediately.

    The outcome is neither fetched nor re-raised; use `get_result` for that.
    """
    if isinstance(x, Future):
        wait_future([x])
        return
    if isinstance(x, ray.ObjectRef):
        ray.wait([x])
def wait(futures: Union[Tuple, List, Set, Dict, Any]) -> None:
    """
    Blocks until a single future, or every future in a collection, has finished.

    Lists, tuples and sets are waited on element by element. For dicts, the values are waited on. Any other object is
    waited on directly if it is a future, and ignored otherwise. Results are not fetched; use `accumulate` for that.

    :param futures: a Future / Ray ObjectRef, a list/tuple/set of them, or a dict whose values are them.
    :return: None.
    """
    if isinstance(futures, (list, tuple, set)):
        items = futures
    elif isinstance(futures, dict):
        items = futures.values()
    else:
        items = (futures,)
    for item in items:
        wait_if_future(item)
@validate_arguments
def retry(fn, *args, retries: conint(ge=1) = 5, wait: confloat(gt=0.0) = 10.0, jitter: confloat(ge=0.0) = 0.5,
          **kwargs):
    """
    Retries a function call a certain number of times, waiting between calls (with a jitter in the wait period).

    :param fn: the function to call.
    :param args: positional arguments forwarded to the function.
    :param retries: max number of times to try.
    :param wait: average wait period (seconds) between retries.
    :param jitter: limit of jitter (+-), as a fraction of `wait`. E.g. jitter=0.1 means we wait for a random time
        period in the range (0.9 * wait, 1.1 * wait) seconds; jitter=0 waits exactly `wait` seconds. The sampled
        wait is clamped at 0, so jitter > 1 is safe.
    :param kwargs: keyword arguments forwarded to the function. Keywords named `retries`, `wait` or `jitter` are
        consumed by `retry` itself and are not forwarded.
    :return: the function's return value if any call succeeds.
    :raise: RuntimeError if all `retries` calls fail. Its message contains the latest traceback, and it is chained
        (`__cause__`) to the latest exception.
    """
    wait = float(wait)
    latest_exception: Optional[BaseException] = None
    latest_traceback: Optional[str] = None
    for retry_num in range(retries):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            latest_exception = e
            latest_traceback = traceback.format_exc()
            print(f'Function call failed with the following exception:\n{latest_traceback}')
            if retry_num < (retries - 1):
                print(f'Retrying {retries - (retry_num + 1)} more times...\n')
                ## With jitter > 1 the lower bound is negative, and time.sleep() rejects negative durations.
                time.sleep(max(0.0, np.random.uniform(wait - wait * jitter, wait + wait * jitter)))
    raise RuntimeError(
        f'Function call failed {retries} times.\nLatest exception:\n{latest_traceback}\n'
    ) from latest_exception
def daemon(wait: float, exit_on_error: bool = False, sentinel: Optional[List] = None, **kwargs):
    """
    A decorator which runs a function repeatedly in a background daemon thread.

    You do not need to invoke the decorated function; simply decorating it starts running it in the background. The
    loop runs in a `threading.Thread(daemon=True)`, so it never keeps the interpreter alive. When the program exits,
    the loop is stopped abruptly, possibly mid-call.

    Example using class method: your daemon should be marked with @staticmethod. Example:
        class Printer:
            DATA_LIST = []
            @staticmethod
            @daemon(wait=3, mylist=DATA_LIST)
            def printer_daemon(mylist):
                if len(mylist) > 0:
                    print(f'Contents of list: {mylist}', flush=True)

    Example using sentinel:
        run_sentinel = [True]
        @daemon(wait=1, sentinel=run_sentinel)
        def run():
            print('Running', flush=True)
        time.sleep(3)  ## Prints "Running" about 3 times.
        run_sentinel.pop()  ## Stops "Running" from printing any more.

    :param wait: the target time in seconds between the starts of successive invocations of the decorated function.
        The time spent inside the call is subtracted from the sleep.
    :param exit_on_error: whether to stop the daemon when the function raises. Either way, the traceback is printed.
    :param sentinel: can be used to stop the daemon. When not passed, the daemon runs until the program exits. When
        passed, `sentinel` must be a list with exactly one element (it can be anything). To stop the daemon, run
        "sentinel.pop()"; the loop exits once its current call and sleep finish. It is important to pass a list (not
        a tuple), since lists are mutable, so the daemon and the caller share the same object.
    :param kwargs: arguments passed to the decorator, which are forwarded to the decorated function as kwargs.
        These values never change for the life of the daemon. However, if you pass references to mutables such as
        lists, dicts, objects etc. and use them in the daemon function, you can run tasks at a regular cadence on
        fresh data.
    :return: a decorator. The decorated name is bound to a stub which raises RuntimeError if called.
    :raise ValueError: if `sentinel` is passed but is not a single-element list.
    """
    ## Refs on how decorators work:
    ## 1. https://www.datacamp.com/community/tutorials/decorators-python
    import functools
    import threading

    def decorator(function):
        ## Validate before starting anything, so a bad sentinel never leaves a stray thread behind.
        if sentinel is not None:
            if not isinstance(sentinel, list) or len(sentinel) != 1:
                raise ValueError('When passing `sentinel`, it must be a list with exactly one item.')

        ## Each decorated function gets its own thread. If you decorate both `say_hi` and `say_bye`, one thread calls
        ## `say_hi` repeatedly and another calls `say_bye` repeatedly; they do not interact.
        def run_function_forever():
            while sentinel is None or len(sentinel) > 0:
                start = time.perf_counter()
                try:
                    function(**kwargs)
                except Exception:
                    print(traceback.format_exc())
                    if exit_on_error:
                        return
                elapsed = time.perf_counter() - start
                time.sleep(max(0.0, wait - elapsed))

        ## A real daemon thread: unlike ThreadPoolExecutor workers (joined at interpreter exit since Python 3.9),
        ## it cannot block shutdown when the loop never ends.
        threading.Thread(
            target=run_function_forever,
            name=f'daemon:{getattr(function, "__name__", repr(function))}',
            daemon=True,
        ).start()

        ## The wrapper does nothing useful, since the daemon function must not be called explicitly.
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            raise RuntimeError('Cannot call daemon function explicitly')

        return wrapper

    return decorator
## Registry of daemons started via `start_daemon`: maps each daemon id to its sentinel list.
_DAEMONS: Dict[str, List[bool]] = {}


def start_daemon(fn, wait: float, id: Optional[str] = None, **kwargs) -> str:
    """
    Starts calling `fn(**kwargs)` every `wait` seconds in a background daemon, and returns the daemon's id.

    Exceptions raised by `fn` are printed, and the daemon keeps running.

    :param fn: function to run repeatedly; it receives `kwargs` as keyword arguments on every call.
    :param wait: seconds between the starts of successive calls (non-negative).
    :param id: unique name for the daemon. When omitted, the current local timestamp (with its UTC offset) is used.
        A " (n)" suffix is appended if that timestamp id is already taken (e.g. two daemons started within one clock
        tick).
    :param kwargs: forwarded to `fn` on every call. Keywords named `fn`, `wait` or `id` are consumed here.
    :return: the daemon id, to be passed to `stop_daemon`.
    :raise AssertionError: if `wait` is not a non-negative number, or an explicitly passed `id` is empty or already
        in use.
    """
    assert isinstance(wait, (int, float)) and wait >= 0.0
    if id is None:
        ## astimezone() on a naive local time attaches the local timezone, so the id always carries an offset.
        base_id: str = datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S.%f UTC%z').strip()
        id = base_id
        suffix: int = 1
        while id in _DAEMONS:
            id = f'{base_id} ({suffix})'
            suffix += 1
    assert isinstance(id, str) and len(id) > 0
    assert id not in _DAEMONS, f'Daemon with id "{id}" already exists.'
    daemon_sentinel: List[bool] = [True]

    @daemon(wait=wait, sentinel=daemon_sentinel)
    def run():
        fn(**kwargs)

    _DAEMONS[id] = daemon_sentinel
    return id
def stop_daemon(id: str) -> bool:
    """
    Stops the daemon that `start_daemon` registered under `id`.

    The daemon's loop exits once its current call and sleep finish.

    :return: True if a daemon with this id was registered (it has now been told to stop), False if the id is unknown.
    """
    assert isinstance(id, str) and len(id) > 0
    ## An unknown id yields this stand-in sentinel, whose single element makes us report False.
    missing_marker: List[bool] = [False]
    sentinel: List[bool] = _DAEMONS.pop(id, missing_marker)
    assert len(sentinel) == 1
    stopped: bool = sentinel.pop()  ## Emptying the list is what makes the daemon loop exit.
    return stopped
## Ref: https://docs.ray.io/en/latest/data/dask-on-ray.html#callbacks
class RayDaskPersistWaitCallback(RayDaskCallback):
    """
    Dask-on-Ray callback that blocks until every submitted Ray task has finished.

    Intended for calls to .persist() where the caller wants to block until the computation completes (the original
    note refers to block=True).
    """

    def _ray_postsubmit_all(self, object_refs, dsk):
        ## Called by Ray once all tasks are submitted; block here until they are all done.
        wait(futures=object_refs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment