Skip to content

Instantly share code, notes, and snippets.

@skasero
Forked from schaumb/st_route.py
Last active July 15, 2025 18:01
Show Gist options
  • Save skasero/c8d1a6e875003b19351b353c3c2a30f4 to your computer and use it in GitHub Desktop.
Save skasero/c8d1a6e875003b19351b353c3c2a30f4 to your computer and use it in GitHub Desktop.
Streamlit - Support custom HTTP requests
###########
## Original taken from https://gist.github.com/schaumb/d557dabf0beced7dfaa1be7acc09b1e4
## Modified and improved for better reliability
###########
import gc
import weakref
import logging
from typing import Optional, Callable, Union, Dict, Set
from collections import defaultdict
from streamlit.runtime import Runtime
from streamlit.runtime.scriptrunner import get_script_run_ctx, add_script_run_ctx
from streamlit.web.server.server_util import make_url_path_regex
from tornado import httputil
from tornado.httputil import HTTPServerRequest, ResponseStartLine, HTTPHeaders
from tornado.routing import Rule, AnyMatches, ReversibleRuleRouter
from tornado.web import Application
logger = logging.getLogger(__name__)
class _RouteRegister:
    """Registry that attaches custom HTTP routes to Streamlit's Tornado server.

    A single instance is shared per Streamlit ``Runtime``: it is stored on the
    runtime object under ``_streamlit_route_register`` and cached on the class.
    Routes live in a ``ReversibleRuleRouter`` that is added to the Tornado
    ``Application`` the first time the instance is created.
    """

    # Lazily created singleton, see ``instance()``.
    _instance: Optional['_RouteRegister'] = None
    # Tornado application located by ``_find_application()``; cached once found.
    _application: Optional['Application'] = None

    @staticmethod
    def _normalize_result(res):
        """Translate a route function's return value into response parts.

        Accepted shapes are ``None``, an ``int`` (status code), ``bytes``/``str``
        (body, status 200), ``(status, body)`` and ``(status, body, headers)``.

        Returns:
            A ``(status_code, body_bytes, extra_headers)`` tuple. A missing
            body becomes ``b""``; ``str`` bodies are encoded as UTF-8; other
            body objects are converted with ``str()`` first.

        Raises:
            TypeError: if ``res`` has any other type or tuple length.
        """
        response_code = 200
        response_body = None
        extra_headers = None

        if isinstance(res, tuple):
            if len(res) == 2:
                response_code, response_body = res
            elif len(res) == 3:
                response_code, response_body, extra_headers = res
            else:
                raise TypeError('Unknown return type from handler.')
        elif isinstance(res, int):
            response_code = res
        elif isinstance(res, (bytes, str)):
            response_body = res
        elif res is not None:
            raise TypeError('Unknown return type from handler.')

        if response_body is None:
            # A status-only result must not send the literal text "None".
            response_body = b""
        elif not isinstance(response_body, (bytes, str)):
            response_body = str(response_body)
        if isinstance(response_body, str):
            response_body = response_body.encode("utf-8")

        return response_code, response_body, dict(extra_headers or {})

    @staticmethod
    def handler(self: 'HTTPServerRequest', path_args, path_kwargs, func: Callable,
                ctx: Optional[weakref.ref]):
        """Run a registered route function and write its HTTP response.

        Args:
            self: the incoming Tornado request (named ``self`` for historical
                reasons; this is a static method).
            path_args: values of the unnamed regex groups of the matched path.
            path_kwargs: values of the named regex groups of the matched path.
            func: the user function; called with the keyword arguments
                ``path_args``, ``path_kwargs``, ``method``, ``body`` and
                ``arguments``.
            ctx: weak reference to the script run context captured when the
                route was registered, or ``None``.

        Any exception raised by ``func`` or by the interpretation of its
        return value is logged and answered with a ``500`` plain-text response.
        """
        headers = HTTPHeaders()
        previous_ctx = get_script_run_ctx()
        try:
            current_ctx = ctx() if ctx else None
            if current_ctx:
                # Let the route function run under the registering session's context.
                add_script_run_ctx(ctx=current_ctx)
            res = func(path_args=path_args,
                       path_kwargs=path_kwargs,
                       method=self.method,
                       body=self.body,
                       arguments=self.arguments)
            response_code, response_body, extra_headers = \
                _RouteRegister._normalize_result(res)
            for header_key, header_val in extra_headers.items():
                headers.add(header_key, header_val)
        except Exception as e:
            logger.exception(f"Route handler error: {e}")
            # Start from clean headers so a half-applied user header set is not sent.
            headers = HTTPHeaders()
            headers.add("Content-Type", "text/plain")
            response_code = 500
            response_body = f"Internal Server Error: {str(e)}".encode("utf-8")
        finally:
            # NOTE(review): if no context was attached before the call, this may
            # not detach the route's context again - confirm against
            # add_script_run_ctx's handling of ctx=None.
            add_script_run_ctx(ctx=previous_ctx)

        # The status line and headers must be written before any body bytes.
        self.connection.write_headers(
            ResponseStartLine(self.version, response_code,
                              httputil.responses.get(response_code, "Unknown")),
            headers,
        )
        self.connection.write(response_body)
        self.connection.finish()

    @classmethod
    def _find_application(cls) -> Optional['Application']:
        """Locate the Tornado ``Application`` instance, caching the result.

        Tries the ``_tornado_app`` attribute of the Streamlit runtime first and
        falls back to scanning the objects tracked by the garbage collector.

        Returns:
            The application, or ``None`` (after logging a warning) if neither
            lookup finds one.
        """
        if cls._application is not None:
            return cls._application

        # Method 1: ask the Runtime instance.
        try:
            candidate = getattr(Runtime.instance(), '_tornado_app', None)
        except Exception as exc:
            logger.debug("Runtime lookup for the Tornado application failed: %s", exc)
            candidate = None
        if candidate is not None:
            cls._application = candidate
            return cls._application

        # Method 2: GC search as fallback.
        for obj in gc.get_objects():
            try:
                found = isinstance(obj, Application)
            except Exception:
                # Some tracked objects (e.g. dead proxies) fail on type checks;
                # skip them instead of aborting the whole scan.
                continue
            if found:
                cls._application = obj
                return cls._application

        logger.warning("Could not find Tornado Application instance")
        return None

    @classmethod
    def instance(cls) -> '_RouteRegister':
        """Return the shared register, creating and installing it on first use.

        Raises:
            RuntimeError: if a new register is needed but no Tornado
                application can be found.
        """
        if cls._instance is not None:
            return cls._instance

        inst: Runtime = Runtime.instance()
        res: Optional[_RouteRegister] = getattr(inst, '_streamlit_route_register', None)
        if res is None:
            app = cls._find_application()
            if app is None:
                raise RuntimeError("Could not find Tornado Application instance")
            res = _RouteRegister()
            # Route every request through our router; unmatched paths fall through.
            app.add_handlers(".*", [Rule(AnyMatches(), res._the_rules)])
            setattr(inst, '_streamlit_route_register', res)
            logger.info("RouteRegister instance created and registered")
        cls._instance = res
        return cls._instance

    def __init__(self):
        """Create an empty router plus the bookkeeping dictionaries."""
        self._the_rules: ReversibleRuleRouter = ReversibleRuleRouter([])
        self._the_rules.rules = []
        # full path regex -> Tornado-facing wrapper (kept alive here).
        self._registered_routes: Dict[str, Callable] = {}
        # full path regex -> weak reference to the registering script context.
        self._route_contexts: Dict[str, weakref.ref] = {}
        self._deregists: Dict[str, Set] = defaultdict(set)
        logger.debug("RouteRegister initialized")

    @staticmethod
    def _get_full_path(path: str, globally: bool, trailing_slash: bool, session_id: str) -> str:
        """Build the URL regex for ``path``, prefixed by ``session_id`` unless global."""
        return make_url_path_regex(*(() if globally else (session_id,)),
                                   path,
                                   trailing_slash=trailing_slash)

    def regist_or_replace(self, path: str, globally: bool,
                          trailing_slash: bool, f: Callable):
        """Register ``f`` for ``path``, replacing any rule with the same full path.

        The session id of the current script run context (or ``"global"`` when
        there is none) is used as the path prefix for non-global routes.
        """
        try:
            ctx = get_script_run_ctx()
            session_id = ctx.session_id if ctx else "global"
        except Exception:
            logger.warning("Could not get script context, using global session")
            session_id = "global"
            ctx = None

        full_path = _RouteRegister._get_full_path(path, globally, trailing_slash, session_id)

        def make_tornado_wrapper(func, ctx_ref):
            def tornado_handler(request, *regex_args, **regex_kwargs):
                # Tornado's path matcher hands the captured groups over as the
                # keyword arguments ``path_args`` / ``path_kwargs``; unpack them
                # instead of forwarding the keyword dict itself.
                regex_kwargs.pop("target_kwargs", None)
                path_args = regex_kwargs.pop("path_args", None)
                path_kwargs = regex_kwargs.pop("path_kwargs", None)
                if path_args is None:
                    path_args = list(regex_args)
                if path_kwargs is None:
                    path_kwargs = regex_kwargs
                return _RouteRegister.handler(
                    request,
                    path_args=path_args,
                    path_kwargs=path_kwargs,
                    func=func,
                    ctx=ctx_ref
                )
            return tornado_handler

        # Only a weak reference to the context is held, so the register does
        # not keep a finished script run alive.
        ctx_ref = weakref.ref(ctx) if ctx else None
        wrapper = make_tornado_wrapper(f, ctx_ref)

        # Keep a strong reference to the wrapper itself.
        self._registered_routes[full_path] = wrapper
        if ctx_ref:
            self._route_contexts[full_path] = ctx_ref
        else:
            # Do not leave a context from a previous registration behind.
            self._route_contexts.pop(full_path, None)

        # Drop any existing rule for this path before adding the new one.
        self._the_rules.named_rules.pop(full_path, None)
        self._the_rules.rules = [rule for rule in self._the_rules.rules
                                 if rule.name != full_path]

        # (pattern, target, target kwargs, rule name)
        rule = (full_path, wrapper, {}, full_path)
        self._the_rules.add_rules([rule])
        logger.info(f"Route registered: {path} -> {full_path} (globally={globally})")

    def clear_function(self, path: str, globally: bool, trailing_slash: bool, session_id: str):
        """Remove the route registered for ``path`` under ``session_id``, if any."""
        full_path = _RouteRegister._get_full_path(path, globally, trailing_slash, session_id)
        # Remove from named rules
        self._the_rules.named_rules.pop(full_path, None)
        # Remove from our own bookkeeping
        self._registered_routes.pop(full_path, None)
        self._route_contexts.pop(full_path, None)
        # Remove from rules list
        self._the_rules.rules = [rule for rule in self._the_rules.rules
                                 if rule.name != full_path]
        logger.info(f"Route cleared: {full_path}")

    def clear_all(self):
        """Clear all routes"""
        self._the_rules.named_rules.clear()
        self._the_rules.rules.clear()
        self._registered_routes.clear()
        self._route_contexts.clear()
        self._deregists.clear()
        logger.info("All routes cleared")

    def get_registered_routes(self) -> Dict[str, Callable]:
        """Get all currently registered routes for debugging"""
        return self._registered_routes.copy()
def st_route(path: str, globally: bool = False, trailing_slash: bool = True):
    """Decorator factory that exposes a function as a custom HTTP route.

    Args:
        path: route path (may contain regex groups); must be a non-empty string.
        globally: when ``True`` the session id is not prepended to the path.
        trailing_slash: forwarded to the URL regex builder.

    Returns:
        A decorator that registers the function and returns it unchanged,
        with an added ``clear()`` attribute that removes the route again.

    Raises:
        AttributeError: if ``path`` is not a non-empty string.
    """
    if not isinstance(path, str) or not path:
        raise AttributeError('First argument must be a not empty path')

    def wrap(f: Callable):
        try:
            # Remember the session the route is registered under, so that
            # ``clear()`` removes this exact route even when it is later called
            # without (or under another) script run context.
            try:
                ctx = get_script_run_ctx()
                registered_session_id = ctx.session_id if ctx else "global"
            except Exception:
                registered_session_id = "global"

            rr: _RouteRegister = _RouteRegister.instance()
            rr.regist_or_replace(path, globally, trailing_slash, f)

            def clear_route():
                try:
                    rr.clear_function(path, globally, trailing_slash,
                                      registered_session_id)
                except Exception as e:
                    logger.warning(f"Failed to clear route {path}: {e}")

            setattr(f, 'clear', clear_route)
            return f
        except Exception as e:
            logger.error(f"Failed to register route {path}: {e}")
            raise

    return wrap
def clear_all_routes():
    """Remove every registered route; failures are logged, not raised."""
    try:
        register = _RouteRegister.instance()
        register.clear_all()
    except Exception as exc:
        logger.error(f"Failed to clear all routes: {exc}")
# Expose the clear-everything helper as ``st_route.clear``.
setattr(st_route, 'clear', clear_all_routes)
'''
Basic usage:
```
@st_route(path='get_name/(.*)')
def any_name(
path_args: List[str],
path_kwargs: Dict[str, Any],
method: str,
body: bytes,
arguments: Dict[str, Any]
) -> Union[int, bytes, str, Tuple[int, Union[bytes, str]], Tuple[int, Union[bytes, str], Dict[str, Any]]]:
"""
path_args: path regex unnamed groups
path_kwargs: path regex named groups
method: HTTP method
body: the request body in bytes
arguments: The query and body arguments
Returns any of the following:
int: HTTP response code
bytes/str: HTTP 200 with body. str is encoded as UTF-8
Tuple[int, bytes/str]: HTTP response code with body
Tuple[int, bytes/str, Dict[str, Any]]: HTTP response code with body and additional headers
If you don't need any of the arguments, just use **kwargs.
"""
return "Hello " + path_args[0].decode()
get_script_run_ctx().session_id
```
This exposes `localhost:8501/<session_id>/get_name/anyone`, which responds with HTTP 200 and "Hello anyone".
This lets the frontend communicate with the running code and fetch already-calculated values.
Functions wrapped with `st_route` can use `st.` commands, but only while the script is still running.
If you set the `globally` argument of `st_route` to `True`, the <session_id> is not added to the path.
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment