Multitier architecture of webapp invoking slow functions asynchronously using Dash, Flask, diskcache, and concurrent.futures
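In outline: the Dash callback does not run the slow function itself. It calls a Flask endpoint over HTTP; the endpoint consults a persistent diskcache and, on a miss or a stale hit, enqueues a jobspec into a disk-backed queue; a background worker thread pulls jobspecs and submits them to a ThreadPoolExecutor, which runs the slow function and writes the timestamped result back to the cache for later requests to pick up.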
import concurrent.futures
import datetime
import inspect
import logging
import os
import time
import threading
from typing import Callable, Any

import dash
import diskcache
import requests
from flask import Flask
from dash import html, Input, Output

# global dict to remember the defined job functions
# key = function name (unique)
# value = tuple of function object, timeout, and arg specs
# this dict is read-only after all functions are loaded
defined_jobs = dict()

def jobfunction(timeout: float = 0.0) -> Callable:
    """Decorator to register a job function that can be called by the worker
    thread. The return value should be a Python data type that Flask can
    serialize into JSON.

    Args:
        timeout: The max age for which a cached return value of the function
            is still considered fresh

    Modifies:
        The global dict `defined_jobs` to remember the decorated function.
    """
    def _deco(fn: Callable) -> Callable:
        argspec = inspect.getfullargspec(fn)
        defined_jobs[fn.__name__] = (fn, timeout, argspec)
        return fn
    return _deco
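
# Illustrative registration (hypothetical function, not part of this app):
#
#     @jobfunction(timeout=60)
#     def fetch_report(name: str) -> dict:
#         ...  # slow work here; the result is cached for 60 seconds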

def clean_args(funcname: str, *args, **kwargs) -> tuple[tuple, dict]:
    """Clean up the args to a function call to normalize their ordering.

    Depends on the global dict `defined_jobs` to find the argspec of the
    target function.
    """
    logger = logging.getLogger("cleanargs")
    # get the function object
    _fn, _timeout, argspec = defined_jobs[funcname]
    logger.debug("Input: %s, %s", args, kwargs)
    # populate the args to call the function, preferring positional args
    fnargs = list(args)
    for i in range(len(fnargs), len(argspec.args)):
        argname = argspec.args[i]
        if argname not in kwargs:
            break  # stop at the first gap so later args do not shift position
        fnargs.append(kwargs[argname])
    fnargs = tuple(fnargs)
    # populate the kwargs to call the function, in sorted order of keys
    consumed = set(argspec.args[:len(fnargs)])
    kwkeys = sorted(k for k in kwargs if k not in consumed)
    fnkwargs = {k: kwargs[k] for k in kwkeys}
    logger.debug("Output: %s, %s", fnargs, fnkwargs)
    return fnargs, fnkwargs
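
# Illustrative normalization (hypothetical registered job `foo(x, y)`):
#   clean_args("foo", 1, y=2)   ->  ((1, 2), {})
#   clean_args("foo", x=1, y=2) ->  ((1, 2), {})
# Both spellings map to the same jobspec, so they share one cache entry.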

def run_job(funcname: str, *args, **kwargs) -> None:
    """Call a job function by its name. Expects args and kwargs to be
    normalized already, for consistency of keys in the persistent cache. If
    a value for the call is already in the persistent cache, it is refreshed
    unconditionally.

    Depends on the global dict `defined_jobs` to find the target function.

    Example:
        @jobfunction()
        def foo(x, y, arg3):
            pass
        run_job("foo", 1, 2, arg3=42)  # equivalent to foo(1, 2, arg3=42)

    Modifies:
        The persistent cache, to remember the timestamp and return value of
        the function call. The working scoreboard, which is cleared when the
        function call completes.

    Returns:
        None. The function is called synchronously and the result is
        remembered in the persistent cache together with the completion time.
    """
    # get the function object
    logger = logging.getLogger("run_job")
    fn, _timeout, _argspec = defined_jobs[funcname]
    jobspec = (funcname, args, kwargs)
    logger.debug("Running %s", jobspec)
    # run the job, then remember the completion time and result in the persistent cache
    value = fn(*args, **kwargs)
    timestamp = time.time()
    persistent[jobspec] = (timestamp, value)
    working.delete(jobspec)
    logger.debug("Updated cache for %s", jobspec)

def retrieve_result(funcname, *args, **kwargs) -> tuple[float, Any]:
    """Attempt to get the result of a function call. If a value for the call
    exists in the persistent cache and has not expired, return it. If it has
    expired, return the stale timestamp and value anyway, and enqueue the job
    for an asynchronous refresh. If it is not in the persistent cache at all,
    return None and enqueue the job for an asynchronous run.

    Depends on the global dict `defined_jobs` to resolve the target function
    and create the jobspec.
    Depends on the persistent cache to return the latest value of the
    function call.

    Returns:
        A tuple of two elements: the timestamp of the cached value and the
        return value of the function call.
    """
    logger = logging.getLogger("retrieve")
    # Create jobspec
    args, kwargs = clean_args(funcname, *args, **kwargs)
    _fn, timeout, _argspec = defined_jobs[funcname]
    jobspec = (funcname, args, kwargs)
    # Upon cache miss, enqueue the job and return None
    if jobspec not in persistent:
        logger.debug("Enqueue for cache miss: %s", jobspec)
        enqueue(funcname, *args, **kwargs)
        return None, None
    # Upon cache hit, enqueue the job if expired, and always return the cached result
    timestamp, value = persistent[jobspec]
    now = time.time()
    if timestamp + timeout < now:
        logger.debug("Enqueue for timeout refresh: %s", jobspec)
        enqueue(funcname, *args, **kwargs)
    return timestamp, value
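
# Illustrative caller pattern (hypothetical job `foo`), showing the outcomes
# a caller must handle:
#   ts, val = retrieve_result("foo", 1, y=2)
#   if ts is None:
#       ...  # cache miss: the job was enqueued; poll again later
#   else:
#       ...  # cached value, possibly stale; a refresh may be in flight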

# job queue with deduplication
# enqueue: queue.push(jobspec)
# dequeue: _, jobspec = queue.pull()  # ignore if jobspec is None
# check persistent cache: timestamp, value = persistent[(fn, args, kwargs)]
# mark job as queued/working: working[(fn, args, kwargs)] = 1
# mark job as done: del working[(fn, args, kwargs)]
CACHEBASEDIR = "./cache"
queue = diskcache.Cache(os.path.join(CACHEBASEDIR, "queue"))
working = diskcache.Cache(os.path.join(CACHEBASEDIR, "working"))
persistent = diskcache.Cache(os.path.join(CACHEBASEDIR, "persistent"))

# concurrent job executor
executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

def worker(sleeptime: float = 1.0):
    """Worker function for a background process/thread that consumes jobspecs
    from the queue and launches the jobs.
    """
    logger = logging.getLogger("worker")
    while True:
        _, jobspec = queue.pull()
        if jobspec is not None:
            funcname, args, kwargs = jobspec
            # run the job in the background from a thread pool,
            # equivalent to: run_job(funcname, *args, **kwargs)
            logger.debug("Submit to run: %s", jobspec)
            executor.submit(run_job, funcname, *args, **kwargs)
        time.sleep(sleeptime)
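
# Since all coordination state lives in diskcache on disk, the worker could
# plausibly run in a separate process instead (a sketch, not exercised here):
#
#     import multiprocessing
#     multiprocessing.Process(target=worker, daemon=True).start()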

def enqueue(funcname: str, *args, **kwargs) -> None:
    """Enqueue a job, on the condition that the same job is not already
    queued or running.
    """
    logger = logging.getLogger("enqueue")
    args, kwargs = clean_args(funcname, *args, **kwargs)
    jobspec = (funcname, args, kwargs)
    # push the job to the queue only when not recorded in the scoreboard `working`
    with working.transact():
        if working.get(jobspec) is None:
            logger.debug("Enqueue: %s", jobspec)
            working[jobspec] = 1
            queue.push(jobspec)
        else:
            logger.debug("Skipped enqueue due to duplicate: %s", jobspec)

#
# ---------------------
#

# Initialize the Dash app
TITLE = "Tiered dash app"
server = Flask(TITLE)
app = dash.Dash(
    server=server,
    title=TITLE,
    compress=True,
)

# Layout of the app
app.layout = html.Div([
    html.H1("Simple Button and Text Display"),
    html.Button("Click Me!", id="button", n_clicks=0),
    html.Div(id="output-container", style={"margin-top": "20px", "font-size": "20px"}),
    html.Div(id="output-timestamp", style={"margin-top": "20px", "font-size": "20px"}),
])

# Callback to update the text display
@app.callback(
    Output("output-container", "children"),
    Output("output-timestamp", "children"),
    Input("button", "n_clicks"),
)
def update_output(n_clicks):
    logger = logging.getLogger("dashcall")
    if n_clicks == 0:
        return "Button has not been clicked yet.", "(no timestamp)"
    url = f"http://127.0.0.1:{port}/fetch/{n_clicks}"
    logger.info("GET %s", url)
    resp = requests.get(url)
    data = resp.json()
    text1 = f"Button clicked {n_clicks} times; value: {data['data']}"
    local_tz = datetime.datetime.now().astimezone().tzinfo
    if data["timestamp"] is None:
        text2 = "(timestamp pending)"
    else:
        dt = datetime.datetime.fromtimestamp(data["timestamp"], local_tz)
        text2 = str(dt)
    return text1, text2

# A slow job function: sleep to simulate work, then return a running counter
# (stored in the `working` cache as a convenient persistent scratchpad)
@jobfunction(10)
def identity_func(*args, **kwargs) -> int:
    time.sleep(3)
    return working.incr("dummy", default=1)

# Flask endpoint - so that the cache can also be warmed up by a script
@server.route("/fetch/<int:count>")
def get_output(count):
    logger = logging.getLogger("flaskcall")
    timestamp, data = retrieve_result("identity_func")
    logger.info("%s - %s - %s", count, timestamp, data)
    return {"timestamp": timestamp, "data": data}

# Run the app, using multithreading
port = 8050
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)  # attach a handler so debug logs are visible
    threading.Thread(target=worker, daemon=True).start()
    # option: use `processes=3` instead of `threaded=True` for multiprocessing concurrency
    app.run(host="0.0.0.0", port=port,
            dev_tools_hot_reload=True, dev_tools_ui=True, debug=True,
            threaded=True)