Last active
August 21, 2024 15:24
-
-
Save skannan-maf/f96c8e0774ea261bf5acd9a387d5fa1c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
from datetime import datetime

import pandas as pd
import streamlit as st
from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)
#
# DATAFRAME FILTER FOR STREAMLIT (simple code, lots of features, works like a charm)
#
# The function below is mostly taken from https://github.com/tylerjrichards/st-filter-dataframe/blob/main/streamlit_app.py
# with minor modifications to suit our purposes (e.g. including vs. excluding nulls, and the conditions under which a column is treated as categorical).
#
# ALTERNATIVES EXPLORED
# Another alternative that I considered was the "streamlit-dynamic-filters" package (https://github.com/arsentievalex/streamlit-dynamic-filters),
# but it is suitable only for filtering on categorical columns.
# The approach used here also supports numeric and date columns, which makes it more powerful.
#
#
def _keep_rows(df, column, mask, include_nulls):
    """Return the rows of ``df`` selected by ``mask``, optionally also keeping rows where ``column`` is null."""
    if include_nulls:
        mask = mask | df[column].isnull()
    return df[mask]


def filter_dataframe(df, base_key, modification_container=None, as_categorical=None):
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    The widget shown for each selected column depends on its contents:
    a multiselect for categorical-like columns, a range slider for numeric
    columns, a date-range picker for datetime columns and a substring/regex
    text box for everything else. Filters are applied cumulatively, in the
    order the columns were selected.

    Args:
        df (pd.DataFrame): Original dataframe (not modified; a copy is filtered)
        base_key: Base of the widget key to be used for keys in this routine
        modification_container: A container within which the column-filter widgets will be created.
            If None, a new container will be created. Note that the "Include nulls?" checkbox is
            created before (outside of) this container.
        as_categorical: List of column names that the caller can ask to be treated as categorical
            columns. Defaults to no columns.

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    # Avoid a mutable default argument; an empty tuple behaves like the old `[]`.
    as_categorical = () if as_categorical is None else as_categorical

    df = df.copy()

    include_nulls = st.checkbox('Include nulls?', value=False, key=base_key + '/inclnulls')

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Deliberate best effort: a column that does not parse as
                # dates is simply left as-is.
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    if modification_container is None:
        modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect(
            "Filter dataframe on",
            df.dropna(how='all', axis=1).columns,
            key=base_key + '/to_filter',
        )
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")

            # Same key as `base_key + column` for string labels, but does not
            # raise for non-string column labels (e.g. integers).
            widget_key = f"{base_key}{column}"

            this_column = df[column].dropna()
            if this_column.empty:
                # Earlier filters may have removed every non-null value of this
                # column; there is nothing to build a widget from.
                right.write(f"No non-null values left in {column} to filter on")
                continue

            unique_values = this_column.unique()
            n_unique = len(unique_values)
            numeric = is_numeric_dtype(this_column)
            temporal = is_datetime64_any_dtype(this_column)

            # Treat as categorical: real categorical dtypes, columns the caller
            # asked for, constant numeric columns (a slider needs min < max),
            # and non-numeric, non-datetime columns with < 20 unique values.
            treat_as_categorical = (
                isinstance(this_column.dtype, pd.CategoricalDtype)
                or column in as_categorical
                or (numeric and n_unique == 1)
                or (not temporal and not numeric and n_unique < 20)
            )

            if treat_as_categorical:
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    unique_values,
                    default=unique_values[0],
                    key=widget_key,
                )
                # An empty selection means "do not filter on this column".
                if len(user_cat_input) > 0:
                    df = _keep_rows(df, column, df[column].isin(user_cat_input), include_nulls)
            elif numeric:
                _min = float(this_column.min())
                _max = float(this_column.max())
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                    format='%.4f',
                    key=widget_key,
                )
                if len(user_num_input) == 2:
                    df = _keep_rows(df, column, df[column].between(*user_num_input), include_nulls)
            elif temporal:
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        datetime(this_column.min().year, 1, 1),
                        datetime(this_column.max().year, 12, 31),
                    ),
                    key=widget_key,
                )
                # While the user is still picking the range, only one date is
                # returned; filter only once both ends are known.
                if len(user_date_input) == 2:
                    start_date, end_date = map(pd.to_datetime, user_date_input)
                    # The picker yields whole days (midnight), so compare against
                    # the start of the following day to keep every timestamp that
                    # falls on the selected end day.
                    in_range = (df[column] >= start_date) & (df[column] < end_date + pd.Timedelta(days=1))
                    df = _keep_rows(df, column, in_range, include_nulls)
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                    key=widget_key,
                )
                if user_text_input:
                    try:
                        # na=False keeps the mask strictly boolean when the
                        # column contains nulls or non-string values.
                        matches = df[column].str.contains(user_text_input, na=False)
                    except re.error as err:
                        # A half-typed or invalid pattern should not crash the app.
                        right.error(f"Invalid regular expression: {err}")
                        continue
                    df = _keep_rows(df, column, matches, include_nulls)

    return df
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment