@righthandabacus
Created November 23, 2024 00:50
Multitier architecture of webapp invoking slow functions asynchronously using Dash, Flask, diskcache, and concurrent.futures
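#
# Overview of the request flow implemented in this file:
#
#   Dash callback --HTTP GET--> Flask endpoint /fetch/<count>
#     --> retrieve_result(): read the persistent cache
#           on a miss or an expired entry --> enqueue(): scoreboard + job queue
#   worker thread: queue.pull() --> ThreadPoolExecutor --> run_job()
#     --> call the @jobfunction-decorated function, write (timestamp, value)
#         to the persistent cache, and clear the scoreboard entry
#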
import concurrent.futures
import datetime
import inspect
import logging
import os
import time
import threading
from typing import Callable, Any
import dash
import diskcache
import requests
from flask import Flask
from dash import Dash, html, dcc, Input, Output
# global dict to remember the defined job functions
# key = function name (unique)
# value = tuple of function object, timeout, and arg specs
# this dict is read only after all functions are loaded
defined_jobs = dict()
def jobfunction(timeout: float = 0.0) -> Callable:
    """Decorator to remember the job functions that can be called by a worker
    thread. The return value should be plain Python data types so that Flask
    can serialize it into JSON.

    Args:
        timeout: The max age for which the return value of the function call
            is still considered valid

    Modifies:
        The global dict `defined_jobs` to remember the decorated function.
    """
    def _deco(fn: Callable) -> Callable:
        argspec = inspect.getfullargspec(fn)
        defined_jobs[fn.__name__] = (fn, timeout, argspec)
        return fn
    return _deco

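# A minimal usage sketch (hypothetical function name and timeout, not part of
# the app below): register a slow function with a 60-second freshness window.
#
#   @jobfunction(timeout=60)
#   def slow_sum(x, y):
#       time.sleep(5)
#       return x + y
#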
def clean_args(funcname: str, *args, **kwargs) -> tuple[tuple, dict]:
    """Clean up the args to a function so as to normalize their ordering.

    Depends on the global dict `defined_jobs` to find the argspec of the
    target function.
    """
    logger = logging.getLogger("cleanargs")
    # get the function's argspec
    _fn, _timeout, argspec = defined_jobs[funcname]
    logger.debug("Input: %s, %s", args, kwargs)
    # populate the args to call the function, preferring positional args
    fnargs = list(args)
    for i in range(len(fnargs), len(argspec.args)):
        argname = argspec.args[i]
        if argname in kwargs:
            fnargs.append(kwargs[argname])
    fnargs = tuple(fnargs)
    # populate the kwargs to call the function, in sorted order of keys
    kwkeys = sorted([k for k in kwargs if k not in argspec.args])
    fnkwargs = {k: kwargs[k] for k in kwkeys}
    logger.debug("Output: %s, %s", fnargs, fnkwargs)
    return fnargs, fnkwargs

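# For the hypothetical slow_sum(x, y) sketched above, clean_args normalizes
# equivalent calls to the same canonical form, so they map to one jobspec key:
#
#   clean_args("slow_sum", 1, 2)    -> ((1, 2), {})
#   clean_args("slow_sum", 1, y=2)  -> ((1, 2), {})
#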
def run_job(funcname: str, *args, **kwargs) -> None:
    """Call the job function by its name. Expects the args and kwargs to be
    normalized already, for consistency of keys in the persistent cache. If a
    value is in the persistent cache already, it is refreshed unconditionally.

    Depends on the global dict `defined_jobs` to find the target function.

    Example:
        @jobfunction()
        def foo(x, y, arg3):
            pass

        run_job("foo", 1, 2, arg3=42)  # equivalent to foo(1, 2, arg3=42)

    Modifies:
        The persistent cache, to remember the timestamp and return value of
        the function call. The working scoreboard, to remove the entry once
        the function call completes.

    Returns:
        None. The function is called synchronously and the result is
        remembered in the persistent cache together with the time.
    """
    # get the function object
    logger = logging.getLogger("run_job")
    fn, _timeout, _argspec = defined_jobs[funcname]
    jobspec = (funcname, args, kwargs)
    logger.debug("Running %s", jobspec)
    # call the function, then remember the result and the time it was run
    value = fn(*args, **kwargs)
    timestamp = time.time()
    persistent[jobspec] = (timestamp, value)
    working.delete(jobspec)
    logger.debug("Updated cache for %s", jobspec)

def retrieve_result(funcname, *args, **kwargs) -> tuple[float, Any]:
    """Attempt to get the result of a function. If the value of the function
    exists in the persistent cache and has not expired, return it. If it has
    expired, return the stale timestamp and value and enqueue the job for an
    asynchronous refresh. If it is not in the persistent cache, return None
    and enqueue the job for an asynchronous run.

    Depends on the global dict `defined_jobs` to resolve the target function
    and create the jobspec. Depends on the persistent cache to return the
    latest value of the function call.

    Returns:
        A tuple of two elements where the first is the timestamp of the value
        and the second is the return value of the function call.
    """
    logger = logging.getLogger("retrieve")
    # create the jobspec
    args, kwargs = clean_args(funcname, *args, **kwargs)
    _fn, timeout, _argspec = defined_jobs[funcname]
    jobspec = (funcname, args, kwargs)
    # upon cache miss, enqueue the job and return None
    if jobspec not in persistent:
        logger.debug("Enqueue for cache miss: %s", jobspec)
        enqueue(funcname, *args, **kwargs)
        return None, None
    # upon cache hit, enqueue the job if expired, and always return the cached result
    timestamp, value = persistent[jobspec]
    now = time.time()
    if timestamp + timeout < now:
        logger.debug("Enqueue for timeout refresh: %s", jobspec)
        enqueue(funcname, *args, **kwargs)
    return timestamp, value

# job queue with deduplication
# enqueue: queue.push(jobspec)
# dequeue: _, jobspec = queue.pull()  # ignore if jobspec is None
# check persistent cache: timestamp, value = persistent[(funcname, args, kwargs)]
# mark job as queued/working: working[(funcname, args, kwargs)] = 1
# mark job as done: del working[(funcname, args, kwargs)]
CACHEBASEDIR = "./cache"
queue = diskcache.Cache(os.path.join(CACHEBASEDIR, "queue"))
working = diskcache.Cache(os.path.join(CACHEBASEDIR, "working"))
persistent = diskcache.Cache(os.path.join(CACHEBASEDIR, "persistent"))
# concurrent job executor
executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
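# Design note: a ThreadPoolExecutor is enough for I/O-bound or sleeping jobs;
# for CPU-bound jobs a concurrent.futures.ProcessPoolExecutor could be
# substituted, provided the job functions and their return values are picklable.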
def worker(sleeptime: float = 1.0):
    """Worker function for a background process/thread that consumes from the
    queue and launches jobs.
    """
    logger = logging.getLogger("worker")
    while True:
        _, jobspec = queue.pull()
        if jobspec is not None:
            funcname, args, kwargs = jobspec
            # run the job in the background from a thread/process pool
            # equivalent to: run_job(funcname, *args, **kwargs)
            logger.debug("Submit to run: %s", jobspec)
            executor.submit(run_job, funcname, *args, **kwargs)
        time.sleep(sleeptime)

def enqueue(funcname: str, *args, **kwargs) -> None:
    """Enqueue a job on the condition that the same job is not already queued
    or running.
    """
    logger = logging.getLogger("enqueue")
    args, kwargs = clean_args(funcname, *args, **kwargs)
    jobspec = (funcname, args, kwargs)
    # push the job to the queue only when not recorded in the scoreboard `working`
    with working.transact():
        if working.get(jobspec) is None:
            logger.debug("Enqueue: %s", jobspec)
            working[jobspec] = 1
            queue.push(jobspec)
        else:
            logger.debug("Skipped enqueue due to duplicate: %s", jobspec)

#
# ---------------------
#
# Initialize the Dash app
TITLE = "Tiered dash app"
server = Flask(TITLE)
app = dash.Dash(
    server=server,
    title=TITLE,
    compress=True,
)
# Layout of the app
app.layout = html.Div([
    html.H1("Simple Button and Text Display"),
    html.Button("Click Me!", id="button", n_clicks=0),
    html.Div(id="output-container", style={"margin-top": "20px", "font-size": "20px"}),
    html.Div(id="output-timestamp", style={"margin-top": "20px", "font-size": "20px"}),
])
# Callback to update the text display
@app.callback(
Output("output-container", "children"),
Output("output-timestamp", "children"),
Input("button", "n_clicks")
)
def update_output(n_clicks):
logger = logging.getLogger("dashcall")
if n_clicks == 0:
return "Button has not been clicked yet.", "(no timestamp)"
url = f"http://127.0.0.1:{port}/fetch/{n_clicks}"
logger.info("GET %s", url)
resp = requests.get(url)
data = resp.json()
text1 = f"Button clicked {n_clicks} times; value: {data['data']}"
local_tz = datetime.datetime.now().astimezone().tzinfo
if data["timestamp"] is None:
text2 = "(timestamp pending)"
else:
dt = datetime.datetime.fromtimestamp(data["timestamp"], local_tz)
text2 = str(dt)
return text1, text2
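# Design note: the callback goes through an HTTP request to this app's own
# Flask endpoint (defined below) rather than calling retrieve_result()
# directly, so that the same code path can also be exercised by an external
# warm-up script hitting /fetch.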
# slow function
@jobfunction(10)
def identity_func(*args, **kwargs) -> int:
    # simulate a slow job: sleep, then return a monotonically increasing counter
    time.sleep(3)
    return working.incr("dummy", default=1)

# Flask endpoint - so that cache can be warmed up by script
@server.route("/fetch/<int:count>")
def get_output(count):
logger = logging.getLogger("flaskcall")
timestamp, data = retrieve_result("identity_func")
logger.info("%s - %s - %s", count, timestamp, data)
return {"timestamp": timestamp, "data": data}
# Run the app with a multithreaded server
port = 8050

if __name__ == "__main__":
    # configure logging so that the debug messages above are actually emitted
    logging.basicConfig(level=logging.DEBUG)
    # start the background worker that consumes the job queue
    threading.Thread(target=worker, daemon=True).start()
    # option: use `processes=3` instead of `threaded=True` for multiprocessing concurrency
    app.run(host="0.0.0.0", port=port,
            dev_tools_hot_reload=True, dev_tools_ui=True, debug=True,
            threaded=True)