Last active
July 24, 2024 15:59
-
-
Save schaumb/d557dabf0beced7dfaa1be7acc09b1e4 to your computer and use it in GitHub Desktop.
Streamlit - Support custom HTTP requests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import gc | |
import weakref | |
from typing import Optional, Callable, Union | |
from weakref import WeakSet | |
from streamlit import config | |
from streamlit.runtime import Runtime | |
from streamlit.runtime.scriptrunner import get_script_run_ctx, add_script_run_ctx | |
from streamlit.web.server.server_util import make_url_path_regex | |
from tornado import httputil | |
from tornado.httputil import HTTPServerRequest, ResponseStartLine, HTTPHeaders | |
from tornado.routing import Rule, AnyMatches, ReversibleRuleRouter | |
from tornado.web import Application | |
class _RouteRegister:
    """Registry of custom HTTP routes mounted into Streamlit's Tornado server.

    A single instance is attached to the Streamlit ``Runtime`` (see
    :meth:`instance`) and plugged into the Tornado ``Application`` as a nested
    ``ReversibleRuleRouter``. The router's ``rules`` container is replaced by a
    ``WeakSet``, so the strong reference to a rule is its entry in
    ``named_rules`` (keyed by the full path regex); popping that entry is what
    unregisters a route once the rule object is garbage-collected.
    """

    @staticmethod
    def handler(self: HTTPServerRequest, path_args, path_kwargs, func: Callable,
                ctx: weakref.ref):
        """Serve one HTTP request with a user route function.

        Bound with ``functools.partial`` (``func``/``ctx``) in
        :meth:`regist_or_replace`; Tornado then calls it with the parsed request
        as the first positional argument plus the matched regex groups.

        :param self: the incoming ``HTTPServerRequest`` (this is a static
            method; the parameter name is historical).
        :param path_args: unnamed regex groups of the matched path.
        :param path_kwargs: named regex groups of the matched path.
        :param func: the function registered via :func:`st_route`.
        :param ctx: weak reference to the registering session's
            ``ScriptRunContext``; it is installed on the current thread while
            ``func`` runs so ``st.`` calls resolve to that session.
        :raises TypeError: if ``func`` returns an unsupported value.
        """
        old = get_script_run_ctx()
        add_script_run_ctx(ctx=ctx())
        try:
            res = func(path_args=path_args,
                       path_kwargs=path_kwargs,
                       method=self.method,
                       body=self.body,
                       arguments=self.arguments)
        finally:
            # Restore the caller's context even when ``func`` raises.
            # NOTE(review): when ``old`` is None, add_script_run_ctx may fall
            # back to the *current* context instead of clearing it, leaving the
            # session ctx attached to this thread — confirm against the
            # installed Streamlit version.
            add_script_run_ctx(ctx=old)

        response_code, response_body, extra_headers = _RouteRegister._normalize_result(res)

        headers = HTTPHeaders()
        for header_key, header_val in extra_headers.items():
            headers.add(header_key, header_val)
        if response_body is not None:
            # Assign (not add) so a user-supplied Content-Length is replaced by
            # the real length instead of being sent twice.
            headers["Content-Length"] = str(len(response_body))
        self.connection.write_headers(
            ResponseStartLine(self.version, response_code,
                              httputil.responses.get(response_code, "Unknown")),
            headers,
        )
        if response_body:
            # Never hand ``None`` to the connection; empty bodies need no write.
            self.connection.write(response_body)
        self.connection.finish()

    @staticmethod
    def _normalize_result(res):
        """Convert a route function's return value into ``(code, body, headers)``.

        Accepted values: ``None`` (empty 200), ``int`` (status code only),
        ``bytes``/``str`` (200 with body), ``(code, body)`` or
        ``(code, body, headers_mapping)``. ``str`` bodies are encoded with
        ``str.encode()`` (UTF-8); ``bytearray``/``memoryview`` bodies are
        copied to ``bytes``.

        :returns: ``(status code, bytes body or None, mapping of extra headers)``.
        :raises TypeError: for any other value or an unsupported body type,
            before anything is written to the connection.
        """
        response_code: int = 200
        response_body: Optional[Union[bytes, str]] = None
        extra_headers = {}
        if isinstance(res, tuple):
            if len(res) == 2:
                response_code, response_body = res
            elif len(res) == 3:
                response_code, response_body, extra_headers = res
            else:
                raise TypeError('Unknown return type from handler.')
        elif isinstance(res, int):
            response_code = res
        elif isinstance(res, (bytes, str)):
            response_body = res
        elif res is not None:
            raise TypeError('Unknown return type from handler.')

        if isinstance(response_body, str):
            response_body = response_body.encode()
        elif isinstance(response_body, (bytearray, memoryview)):
            response_body = bytes(response_body)
        elif response_body is not None and not isinstance(response_body, bytes):
            # Fail before any header is written rather than mid-response.
            raise TypeError('Unknown return type from handler.')
        return response_code, response_body, extra_headers or {}

    @classmethod
    def instance(cls) -> '_RouteRegister':
        """Return the registry attached to the Streamlit ``Runtime``, creating it on first use.

        On first use the registry's router is mounted into the running Tornado
        ``Application`` for every host pattern (``".*"``). The application is
        located by scanning the garbage collector for objects referring to the
        ``Application`` class, since Streamlit does not expose it directly.

        :raises RuntimeError: if no Tornado ``Application`` instance can be found.
        """
        inst: Runtime = Runtime.instance()
        res: Optional[_RouteRegister] = getattr(inst, '_streamlit_route_register', None)
        if res is None:
            app: Optional[Application] = next(
                (k for k in gc.get_referrers(Application) if isinstance(k, Application)),
                None,
            )
            if app is None:
                raise RuntimeError('Could not find the running tornado Application.')
            res = cls()
            app.add_handlers(".*", [Rule(AnyMatches(), res._the_rules)])
            setattr(inst, '_streamlit_route_register', res)
        return res

    def __init__(self):
        self._the_rules: ReversibleRuleRouter = ReversibleRuleRouter([])
        # Hold rules weakly: ``named_rules`` keeps the strong reference, so
        # popping a name there also removes the rule from matching.
        # NOTE: a set has no order, so overlapping route regexes match in an
        # unspecified order.
        self._the_rules.rules = WeakSet()
        # Tornado's add_rules calls ``rules.append``; alias it to WeakSet.add.
        setattr(self._the_rules.rules, 'append', getattr(self._the_rules.rules, 'add'))
        # session_id -> (set of (path, globally, trailing_slash), weakref.proxy(client))
        self._deregists = {}

    @staticmethod
    def _get_full_path(path: str, globally: bool, trailing_slash: bool, session_id: str) -> str:
        """Build the route regex: ``baseUrlPath`` [/ ``session_id`` unless ``globally``] / ``path``."""
        return make_url_path_regex(config.get_option("server.baseUrlPath"),
                                   *(() if globally else (session_id,)),
                                   path,
                                   trailing_slash=trailing_slash)

    def regist_or_replace(self, path: str, globally: bool, trailing_slash: bool, f: Callable):
        """Register ``f`` to serve ``path``, replacing any route with the same full path.

        Must be called while a script is running: the current
        ``ScriptRunContext`` is captured weakly so ``f`` runs with it, and its
        session id is used to build the path (unless ``globally``) and to tie
        the route's lifetime to the session's client — when that client is
        garbage-collected, every route the session registered is cleared.

        :raises RuntimeError: if there is no current script run context.
        """
        ctx = get_script_run_ctx()
        if ctx is None:
            raise RuntimeError('st_route must be used while a Streamlit script is running.')
        session_id = ctx.session_id
        full_path = _RouteRegister._get_full_path(path, globally, trailing_slash, session_id)
        # The full path doubles as the rule name, so re-registering the same
        # path replaces the previous rule in ``named_rules``.
        self._the_rules.add_rules(
            [(full_path, functools.partial(_RouteRegister.handler, func=f, ctx=weakref.ref(ctx)), {}, full_path)]
        )

        def dereg(_dead_client_proxy):
            # Fired when the session's client object is garbage-collected.
            entry = self._deregists.pop(session_id, None)
            if entry is None:
                return
            for route_path, route_globally, route_trailing_slash in entry[0]:
                self.clear_function(route_path, route_globally, route_trailing_slash, session_id)

        if session_id not in self._deregists:
            client = Runtime.instance().get_client(session_id)
            # The proxy is kept in the dict so its callback stays armed.
            self._deregists[session_id] = (set(), weakref.proxy(client, dereg))
        # Store the route parameters (hashable by value) rather than a fresh
        # partial, so reruns that re-register a route don't grow the set.
        self._deregists[session_id][0].add((path, globally, trailing_slash))

    def clear_function(self, path: str, globally: bool, trailing_slash: bool, session_id: str):
        """Unregister one route; a no-op if it is not registered."""
        full_path = _RouteRegister._get_full_path(path, globally, trailing_slash, session_id)
        self._the_rules.named_rules.pop(full_path, None)

    def clear_all(self):
        """Unregister every route of every session."""
        self._the_rules.named_rules.clear()
        # Also empty the WeakSet explicitly instead of waiting for collection.
        self._the_rules.rules.clear()
def st_route(path: str, globally: bool = False, trailing_slash: bool = True):
    """Decorator that exposes a function as an HTTP endpoint of the Streamlit server.

    The decorated function is registered immediately (so this must run inside a
    Streamlit script) and returned unchanged, with an added ``clear()``
    attribute that unregisters the route again.

    :param path: route regex, appended to ``server.baseUrlPath`` and, unless
        ``globally``, to the current session id.
    :param globally: if true, the session id is not added to the path.
    :param trailing_slash: forwarded to ``make_url_path_regex``.
    :raises AttributeError: if ``path`` is not a non-empty string.
    """
    if not isinstance(path, str) or not path:
        raise AttributeError('First argument must be a not empty path')

    def wrap(f: Callable):
        rr: _RouteRegister = _RouteRegister.instance()
        rr.regist_or_replace(path, globally, trailing_slash, f)

        def clear():
            # Global routes don't contain the session id in their path, so
            # they can be cleared even without a running script context.
            session_id = '' if globally else get_script_run_ctx().session_id
            rr.clear_function(path, globally, trailing_slash, session_id)

        setattr(f, 'clear', clear)
        return f

    return wrap
# ``st_route.clear()`` removes every registered route, for all sessions and
# global ones alike. The registry is looked up at call time, not import time.
setattr(st_route, 'clear', lambda: _RouteRegister.instance().clear_all())
'''
Basic usage:
```
@st_route(path='get_name/(.*)')
def any_name(
        path_args: List[Optional[bytes]],
        path_kwargs: Dict[str, Optional[bytes]],
        method: str,
        body: bytes,
        arguments: Dict[str, List[bytes]]
) -> Union[None, int, bytes, str, Tuple[int, Union[bytes, str]], Tuple[int, Union[bytes, str], Dict[str, Any]]]:
    """
    path_args: path regex unnamed groups (URL-unescaped bytes; None for an unmatched optional group)
    path_kwargs: path regex named groups (same encoding as path_args)
    method: HTTP method
    body: the request body in bytes
    arguments: the query and body arguments
    Returns any of the following:
        None: HTTP 200 with an empty body
        int: HTTP response code
        bytes/str: HTTP 200 with body; str is encoded as UTF-8
        Tuple[int, bytes/str]: HTTP response code with body
        Tuple[int, bytes/str, Dict[str, Any]]: HTTP response code with body and additional headers
    All arguments are passed by keyword; if you don't need some of them, accept them with **kwargs.
    """
    return "Hello " + path_args[0].decode()
```
It exposes `localhost:8501/<session_id>/get_name/anyone` (prefixed by `server.baseUrlPath` if one is set),
which returns HTTP 200 with the body "Hello anyone".
The <session_id> is `get_script_run_ctx().session_id` of the session that registered the route.
This lets the frontend communicate with the running code and fetch values that were already computed.
Functions wrapped with `st_route` run with the registering session's script context, so they can use
`st.` commands, but only while the script is running.
If you set the `globally` argument of `st_route` to `True`, the <session_id> is not added to the path.
'''
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
Inspired by st_route.py, I made the following to enable Socket.IO in Streamlit:
st_socketio.py