Last active
February 27, 2024 06:00
-
-
Save ash2shukla/ff180d7fbe8ec3a0240f19f4452acde7 to your computer and use it in GitHub Desktop.
state decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools

from prometheus_client import Counter
from prometheus_client.registry import REGISTRY
from streamlit.hashing import _CodeHasher
from streamlit.report_thread import get_report_ctx
from streamlit.server.server import Server
class _SessionState:
    """Key/value store with attribute-style and item-style access.

    All bookkeeping lives in one ``_state`` dict placed directly in the
    instance ``__dict__``; every other attribute read or write is redirected
    to ``_state["data"]``.
    """

    def __init__(self, session, hash_funcs):
        """Initialize SessionState instance.

        Args:
            session: Object whose ``request_rerun()`` is called by
                ``clear()`` and ``sync()``.
            hash_funcs: Passed unchanged to ``_CodeHasher``.
        """
        # Written through __dict__ because __setattr__ is overridden to store
        # values in the user data instead of on the instance.
        self.__dict__["_state"] = {
            "data": {},
            "hash": None,
            "hasher": _CodeHasher(hash_funcs),
            "is_rerun": False,
            "session": session,
        }

    def __call__(self, **kwargs):
        """Initialize state data once: only keys not yet present are set."""
        data = self._state["data"]
        for item, value in kwargs.items():
            data.setdefault(item, value)

    def __getitem__(self, item):
        """Return a saved state value, None if item is undefined."""
        return self._state["data"].get(item)

    def __getattr__(self, item):
        """Return a saved state value, None if item is undefined.

        Raises:
            AttributeError: For ``_state`` itself and for dunder names.
        """
        # __getattr__ only runs after normal lookup has failed. Reaching it
        # for "_state" means the instance was built without __init__, and
        # reading self._state below would recurse forever. Dunder probes
        # (hasattr(obj, "__deepcopy__") and friends) must see "missing"
        # rather than a None that looks like a defined hook.
        if item == "_state" or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(item)
        return self._state["data"].get(item)

    def __setitem__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def __setattr__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def clear(self):
        """Clear session state and request a rerun."""
        internal = self._state
        internal["data"].clear()
        internal["session"].request_rerun()

    def sync(self):
        """Request one rerun when the data changed since the previous sync.

        The data is hashed and compared with the hash stored by the previous
        call; on a mismatch ``session.request_rerun()`` is called. The stored
        hash is refreshed on every call.
        """
        internal = self._state

        # Ensure to rerun only once to avoid infinite loops
        # caused by a constantly changing state value at each run.
        #
        # Example: state.value += 1
        rerun_in_progress = internal["is_rerun"]
        if rerun_in_progress:
            internal["is_rerun"] = False

        # Hash once and reuse the digest for both the comparison and the
        # stored value.
        current_hash = internal["hasher"].to_bytes(internal["data"], None)

        previous_hash = internal["hash"]
        if (
            not rerun_in_progress
            and previous_hash is not None
            and previous_hash != current_hash
        ):
            internal["is_rerun"] = True
            internal["session"].request_rerun()

        internal["hash"] = current_hash
def _get_session():
    """Return the session object belonging to the current report context.

    Raises:
        RuntimeError: If there is no report context, or the server holds no
            session info for the context's session id.
    """
    ctx = get_report_ctx()
    # NOTE(review): get_report_ctx() may return None (presumably when called
    # outside a script thread); fail with the same error as a missing session
    # instead of an AttributeError on None.
    if ctx is None:
        raise RuntimeError("Couldn't get your Streamlit Session object.")

    session_info = Server.get_current()._get_session_info(ctx.session_id)
    if session_info is None:
        raise RuntimeError("Couldn't get your Streamlit Session object.")

    return session_info.session
def get_state(hash_funcs=None):
    """Return the ``_SessionState`` attached to the current session.

    The state object is created on first use and cached on the session as
    ``_custom_session_state``; later calls return the cached object, so
    ``hash_funcs`` only takes effect on the call that creates it.

    Args:
        hash_funcs: Forwarded to ``_SessionState`` when a new state is built.
    """
    session = _get_session()

    if hasattr(session, "_custom_session_state"):
        return session._custom_session_state

    session._custom_session_state = _SessionState(session, hash_funcs)
    return session._custom_session_state
def _get_names(collector_name, collector_type):
    """Return ``collector_name`` combined with each suffix of its type.

    Args:
        collector_name: Base metric name.
        collector_type: Lower-case type key such as ``"counter"``.

    Returns:
        A list of suffixed names; empty when the type has no entry in the
        suffix table below.
    """
    type_suffixes = {
        'counter': ['_total', '_created'],
        'summary': ['_sum', '_count', '_created'],
        'histogram': ['_bucket', '_sum', '_count', '_created'],
        'gaugehistogram': ['_bucket', '_gsum', '_gcount'],
        'info': ['_info'],
    }
    suffixes = type_suffixes.get(collector_type, [])
    return [collector_name + suffix for suffix in suffixes]
def get_or_create_metric(metric_type, *, name, **kwargs):
    """Return the already-registered metric for ``name`` or create a new one.

    Args:
        metric_type: Metric class (e.g. ``Counter``); its lower-cased
            ``__name__`` selects the suffix list in ``_get_names``.
        name: Metric name, keyword-only.
        **kwargs: Forwarded to ``metric_type`` only when a metric is created.

    Returns:
        The collector found in ``REGISTRY`` under the base name or one of its
        suffixed names, otherwise ``metric_type(name=name, **kwargs)``.
    """
    registered = REGISTRY._names_to_collectors

    # The base name is checked too: types without a suffix entry (e.g. a
    # gauge) yield no suffixed names and would otherwise never be found,
    # so every later call would try to register a duplicate.
    candidates = [name, *_get_names(name, metric_type.__name__.lower())]

    for candidate in candidates:
        # Return the collector stored under the name that actually matched.
        # The isinstance check keeps a same-named metric of another type from
        # being handed back silently; construction below then reports it.
        existing = registered.get(candidate)
        if isinstance(existing, metric_type):
            return existing

    return metric_type(name=name, **kwargs)
def provide_state(func):
    """Decorator that passes the session state to ``func`` as ``state=``.

    On every call the wrapper fetches the state via ``get_state``, calls
    ``count_sessions()``, invokes ``func`` with the extra ``state`` keyword
    argument, then calls ``state.sync()`` and returns ``func``'s result.
    """

    @functools.wraps(func)  # keep func's name/docstring on the wrapper
    def wrapper(*args, **kwargs):
        state = get_state(hash_funcs={})
        count_sessions()
        return_value = func(*args, state=state, **kwargs)
        # Sync only after func has finished changing the state.
        state.sync()
        return return_value

    return wrapper
def count_sessions():
    """Increment the ``session_count`` counter the first time a state is seen.

    A ``_is_session_reused`` flag stored in the state marks it as counted.
    """
    state = get_state(hash_funcs={})
    session_counter = get_or_create_metric(
        Counter, name='session_count', documentation='Unique Sessions'
    )

    # Undefined state keys read as None, so the flag is falsy until set below.
    # NOTE(review): the flag lives in the state data, so state.clear() wipes
    # it and the same session is counted again afterwards - confirm intended.
    if state._is_session_reused:
        return

    session_counter.inc()
    state._is_session_reused = True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Added a session counter explicitly (and, more generally, a way to integrate Prometheus metrics with the session).
There is no need to call count_sessions explicitly if you already use the @provide_state decorator on your main function.