-
-
Save skasero/c8d1a6e875003b19351b353c3c2a30f4 to your computer and use it in GitHub Desktop.
Streamlit - Support custom HTTP requests
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
########### | |
## Original taken from https://gist.github.com/schaumb/d557dabf0beced7dfaa1be7acc09b1e4 | |
## Modified and improved for better reliability | |
########### | |
import gc | |
import weakref | |
import logging | |
from typing import Optional, Callable, Union, Dict, Set | |
from collections import defaultdict | |
from streamlit.runtime import Runtime | |
from streamlit.runtime.scriptrunner import get_script_run_ctx, add_script_run_ctx | |
from streamlit.web.server.server_util import make_url_path_regex | |
from tornado import httputil | |
from tornado.httputil import HTTPServerRequest, ResponseStartLine, HTTPHeaders | |
from tornado.routing import Rule, AnyMatches, ReversibleRuleRouter | |
from tornado.web import Application | |
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class _RouteRegister:
    """Registry of custom HTTP routes served by Streamlit's Tornado server.

    One instance is shared per Streamlit ``Runtime`` (stored on it as the
    ``_streamlit_route_register`` attribute). The instance owns a
    ``ReversibleRuleRouter`` that is mounted once into the Tornado
    ``Application``; routes are then added to / removed from that router
    while the server is running.
    """

    _instance: Optional['_RouteRegister'] = None
    _application: Optional[Application] = None

    @staticmethod
    def handler(self: HTTPServerRequest, path_args, path_kwargs, func: Callable,
                ctx: Optional[weakref.ref]):
        """Run a route function for one request and write its HTTP response.

        Args:
            self: The Tornado request being served (kept under the name
                ``self`` for backward compatibility; this is a static method).
            path_args: Unnamed regex groups captured from the URL path.
            path_kwargs: Named regex groups captured from the URL path.
            func: The user function registered through ``st_route``.
            ctx: Weak reference to the script-run context that was active when
                the route was registered, or ``None``.

        If ``func`` raises, or returns a value of an unsupported shape, the
        error is logged and a complete ``500`` plain-text response is sent.
        """
        try:
            res = _RouteRegister._call_route(self, path_args, path_kwargs, func, ctx)
            response_code, response_body, header_pairs = _RouteRegister._normalize_result(res)
        except Exception as e:
            # Covers both errors inside the route function and invalid return
            # values, so the client always receives a well-formed response.
            logger.exception("Route handler error: %s", e)
            _RouteRegister._write_response(
                self,
                500,
                f"Internal Server Error: {str(e)}".encode("utf-8"),
                [("Content-Type", "text/plain")],
            )
            return
        _RouteRegister._write_response(self, response_code, response_body, header_pairs)

    @staticmethod
    def _call_route(request: HTTPServerRequest, path_args, path_kwargs,
                    func: Callable, ctx: Optional[weakref.ref]):
        """Call ``func`` with the request data while the route's script-run context is attached."""
        old = get_script_run_ctx()
        try:
            current_ctx = ctx() if ctx is not None else None
            if current_ctx is not None:
                add_script_run_ctx(ctx=current_ctx)
            return func(path_args=path_args,
                        path_kwargs=path_kwargs,
                        method=request.method,
                        body=request.body,
                        arguments=request.arguments)
        finally:
            # NOTE(review): Streamlit's add_script_run_ctx(ctx=None) appears to
            # fall back to the thread's *current* context rather than detaching
            # it, so when ``old`` is None the route's context may stay attached
            # to this thread. Confirm against the installed Streamlit version.
            add_script_run_ctx(ctx=old)

    @staticmethod
    def _normalize_result(res):
        """Convert a route function's return value to ``(code, body_bytes, header_pairs)``.

        Accepted shapes: ``None``, ``int`` (status code), ``bytes``/``str``
        (200 with body), ``(code, body)`` and ``(code, body, headers_dict)``.
        A missing body becomes ``b""``; ``str`` bodies are UTF-8 encoded and
        other bodies are ``str()``-ed first.

        Raises:
            TypeError: For any other shape, or a non-int status code.
        """
        response_code = 200
        response_body = None
        extra_headers = None
        if isinstance(res, tuple):
            if len(res) == 2:
                response_code, response_body = res
            elif len(res) == 3:
                response_code, response_body, extra_headers = res
            else:
                raise TypeError('Unknown return type from handler.')
        elif isinstance(res, int):
            response_code = res
        elif isinstance(res, (bytes, str)):
            response_body = res
        elif res is not None:
            raise TypeError('Unknown return type from handler.')

        if not isinstance(response_code, int):
            raise TypeError('Response code returned by handler must be an int.')

        if response_body is None:
            # A bare status code (or None) means "no body" -- never send the
            # literal text "None".
            body = b""
        elif isinstance(response_body, bytes):
            body = response_body
        else:
            body = str(response_body).encode("utf-8")

        # Header values are stringified up front (inside the caller's try) so
        # that an unusable header mapping turns into a 500, not a dropped request.
        header_pairs = [
            (key, value if isinstance(value, (str, bytes)) else str(value))
            for key, value in (extra_headers or {}).items()
        ]
        return response_code, body, header_pairs

    @staticmethod
    def _write_response(request: HTTPServerRequest, code: int, body: bytes,
                        header_pairs) -> None:
        """Write status line, headers and (non-empty) body, then finish the response."""
        headers = HTTPHeaders()
        for key, value in header_pairs:
            headers.add(key, value)
        request.connection.write_headers(
            ResponseStartLine(request.version, code,
                              httputil.responses.get(code, "Unknown")),
            headers,
        )
        if body:
            request.connection.write(body)
        request.connection.finish()

    @classmethod
    def _find_application(cls) -> Optional[Application]:
        """Locate the Tornado ``Application`` and cache it on the class.

        Tries a ``_tornado_app`` attribute on the ``Runtime`` singleton first,
        then falls back to scanning ``gc.get_objects()``. Returns ``None``
        (after logging a warning) when neither lookup succeeds.
        """
        if cls._application is not None:
            return cls._application

        # Method 1: an app exposed on the Runtime (assumption: may not exist in
        # every Streamlit version, hence the fallback below).
        try:
            app = getattr(Runtime.instance(), '_tornado_app', None)
        except Exception:
            logger.debug("Runtime lookup for the Tornado Application failed", exc_info=True)
            app = None
        if isinstance(app, Application):
            cls._application = app
            return app

        # Method 2: GC scan; takes the first Application found.
        # NOTE(review): assumes Streamlit's server app is the only/first one.
        try:
            app = next((obj for obj in gc.get_objects() if isinstance(obj, Application)), None)
        except Exception:
            logger.debug("GC scan for the Tornado Application failed", exc_info=True)
            app = None
        if app is not None:
            cls._application = app
            return app

        logger.warning("Could not find Tornado Application instance")
        return None

    @classmethod
    def instance(cls) -> '_RouteRegister':
        """Return the shared register, creating and mounting it on first use.

        The register is also stored on the ``Runtime`` so that a re-executed
        copy of this class reuses the router already mounted into Tornado
        instead of mounting a second one.

        Raises:
            RuntimeError: If the Tornado Application cannot be found.
        """
        if cls._instance is not None:
            return cls._instance
        inst: Runtime = Runtime.instance()
        res: Optional[_RouteRegister] = getattr(inst, '_streamlit_route_register', None)
        if res is not None:
            cls._instance = res
            return res
        app = cls._find_application()
        if app is None:
            raise RuntimeError("Could not find Tornado Application instance")
        res = _RouteRegister()
        app.add_handlers(".*", [Rule(AnyMatches(), res._the_rules)])
        setattr(inst, '_streamlit_route_register', res)
        cls._instance = res
        logger.info("RouteRegister instance created and registered")
        return res

    def __init__(self):
        self._the_rules: ReversibleRuleRouter = ReversibleRuleRouter([])
        self._the_rules.rules = []
        # full path regex -> Tornado target; keeps the wrappers reachable.
        self._registered_routes: Dict[str, Callable] = {}
        # full path regex -> weak reference to the registering script-run context.
        self._route_contexts: Dict[str, weakref.ref] = {}
        # Not used by any visible code path except clear_all().
        self._deregists: Dict[str, Set] = defaultdict(set)
        logger.debug("RouteRegister initialized")

    @staticmethod
    def _get_full_path(path: str, globally: bool, trailing_slash: bool, session_id: str) -> str:
        """Build the route regex; non-global routes are prefixed with ``session_id``."""
        return make_url_path_regex(*(() if globally else (session_id,)),
                                   path,
                                   trailing_slash=trailing_slash)

    @staticmethod
    def _make_tornado_target(func: Callable, ctx_ref: Optional[weakref.ref]) -> Callable:
        """Build the callable Tornado invokes when a rule matches.

        Tornado's ``RuleRouter`` calls a callable target as
        ``target(request, path_args=[...], path_kwargs={...})`` -- the regex
        groups arrive as *keyword* arguments and are omitted entirely when the
        path regex has no groups, hence the ``None`` defaults.
        """
        def tornado_handler(request, path_args=None, path_kwargs=None):
            return _RouteRegister.handler(
                request,
                path_args=list(path_args) if path_args is not None else [],
                path_kwargs=dict(path_kwargs) if path_kwargs is not None else {},
                func=func,
                ctx=ctx_ref,
            )
        return tornado_handler

    def _remove_rule(self, full_path: str) -> None:
        """Drop any rule named ``full_path`` from the router (no-op if absent)."""
        self._the_rules.named_rules.pop(full_path, None)
        self._the_rules.rules = [rule for rule in self._the_rules.rules
                                 if rule.name != full_path]

    def regist_or_replace(self, path: str, globally: bool,
                          trailing_slash: bool, f: Callable):
        """Register ``f`` under ``path``, replacing any route with the same full path."""
        try:
            ctx = get_script_run_ctx()
            session_id = ctx.session_id if ctx else "global"
        except Exception:
            logger.warning("Could not get script context, using global session")
            session_id = "global"
            ctx = None
        full_path = _RouteRegister._get_full_path(path, globally, trailing_slash, session_id)

        # Only a *weak* reference to the script-run context is kept, so routes
        # do not keep a finished script run's context alive.
        ctx_ref = weakref.ref(ctx) if ctx else None
        wrapper = _RouteRegister._make_tornado_target(f, ctx_ref)

        self._registered_routes[full_path] = wrapper
        if ctx_ref:
            self._route_contexts[full_path] = ctx_ref
        else:
            self._route_contexts.pop(full_path, None)

        # Replace: remove any previous rule with this name, then append the new one.
        self._remove_rule(full_path)
        self._the_rules.add_rules([(full_path, wrapper, {}, full_path)])
        logger.info("Route registered: %s -> %s (globally=%s)", path, full_path, globally)

    def clear_function(self, path: str, globally: bool, trailing_slash: bool, session_id: str):
        """Remove the route registered for ``path`` under ``session_id``."""
        full_path = _RouteRegister._get_full_path(path, globally, trailing_slash, session_id)
        self._remove_rule(full_path)
        self._registered_routes.pop(full_path, None)
        self._route_contexts.pop(full_path, None)
        logger.info("Route cleared: %s", full_path)

    def clear_all(self):
        """Clear all routes"""
        self._the_rules.named_rules.clear()
        self._the_rules.rules.clear()
        self._registered_routes.clear()
        self._route_contexts.clear()
        self._deregists.clear()
        logger.info("All routes cleared")

    def get_registered_routes(self) -> Dict[str, Callable]:
        """Get all currently registered routes for debugging"""
        return self._registered_routes.copy()
def st_route(path: str, globally: bool = False, trailing_slash: bool = True):
    """Decorator that exposes a function as an HTTP endpoint on Streamlit's server.

    Args:
        path: URL path regex for the route (without the session prefix).
        globally: If True, the route is not prefixed with the session id.
        trailing_slash: Passed to ``make_url_path_regex`` (presumably allows
            an optional trailing slash).

    Returns:
        A decorator that registers the function and returns it unchanged,
        with a ``clear()`` attribute that removes the route it registered.

    Raises:
        AttributeError: If ``path`` is not a non-empty string.
    """
    if not isinstance(path, str) or not path:
        raise AttributeError('First argument must be a not empty path')

    def wrap(f: Callable):
        try:
            rr: _RouteRegister = _RouteRegister.instance()
            # Capture the session id at registration time (mirroring
            # regist_or_replace) so clear() removes the route that was actually
            # registered, even if called later from another thread or context.
            try:
                ctx = get_script_run_ctx()
                session_id = ctx.session_id if ctx else "global"
            except Exception:
                session_id = "global"
            rr.regist_or_replace(path, globally, trailing_slash, f)

            def clear_route():
                try:
                    rr.clear_function(path, globally, trailing_slash, session_id)
                except Exception as e:
                    logger.warning("Failed to clear route %s: %s", path, e)

            setattr(f, 'clear', clear_route)
            return f
        except Exception as e:
            logger.error("Failed to register route %s: %s", path, e)
            raise

    return wrap
def clear_all_routes():
    """Drop every route held by the shared route register.

    Obtains the register through ``_RouteRegister.instance()`` (which creates
    one if none exists yet). Any failure is logged, not raised.
    """
    try:
        register = _RouteRegister.instance()
        register.clear_all()
    except Exception as e:
        logger.error(f"Failed to clear all routes: {e}")


# Expose as ``st_route.clear()`` for convenience.
st_route.clear = clear_all_routes
''' | |
Basic usage: | |
``` | |
@st_route(path='get_name/(.*)') | |
def any_name( | |
path_args: List[str], | |
path_kwargs: Dict[str, Any], | |
method: str, | |
body: bytes, | |
arguments: Dict[str, Any] | |
) -> Union[int, bytes, str, Tuple[int, Union[bytes, str]], Tuple[int, Union[bytes, str], Dict[str, Any]]]: | |
""" | |
path_args: path regex unnamed groups | |
path_kwargs: path regex named groups | |
method: HTTP method | |
body: the request body in bytes | |
arguments: The query and body arguments | |
returns any of the following:
int: HTTP response code
bytes/str: HTTP 200 with this body; a str is encoded as UTF-8
Tuple[int, bytes/str]: HTTP response code with body | |
Tuple[int, bytes/str, Dict[str, Any]]: HTTP response code with body and additional headers | |
If you don't need any of the arguments, just use **kwargs. | |
""" | |
return "Hello " + path_args[0].decode() | |
# The <session_id> in the URL is get_script_run_ctx().session_id of the script run that registered the route.
``` | |
This exposes `localhost:8501/<session_id>/get_name/anyone`, which responds with HTTP 200 and "Hello anyone".
This lets the frontend communicate with the running script and fetch values it has already computed.
Functions wrapped with `st_route` can use `st.` commands, but only while the script-run context that registered them still exists (the route holds only a weak reference to it).
If you pass `globally=True` to `st_route`, the <session_id> is not added to the path.
'''
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment