Skip to content

Instantly share code, notes, and snippets.

@skannan-maf
Last active August 21, 2024 15:24
Show Gist options
  • Save skannan-maf/f96c8e0774ea261bf5acd9a387d5fa1c to your computer and use it in GitHub Desktop.
Save skannan-maf/f96c8e0774ea261bf5acd9a387d5fa1c to your computer and use it in GitHub Desktop.
import streamlit as st
import pandas as pd
from pandas.api.types import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_numeric_dtype,
is_object_dtype,
)
from datetime import datetime
#
# DATA FRAME FILTER FOR STREAMLIT (simple code, lots of features, works like a charm)
#
# The function below is mostly taken from https://github.com/tylerjrichards/st-filter-dataframe/blob/main/streamlit_app.py
# Minor modifications were made to suit our purposes (e.g. choosing whether to include or exclude nulls, and the conditions under which a column is treated as categorical).
#
# ALTERNATIVES EXPLORED
# Another alternative I considered was the "streamlit-dynamic-filters" package (https://github.com/arsentievalex/streamlit-dynamic-filters),
# but it only supports filtering on categorical columns.
# The approach used here also supports numeric and date columns, which makes it more powerful.
#
#
def filter_dataframe(df, base_key, modification_container=None, as_categorical=None, max_categories=20):
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Renders an "Include nulls?" checkbox (outside ``modification_container``),
    then, inside the container, a multiselect of columns to filter on. Columns
    that are entirely null are not offered. Each chosen column gets a widget
    that depends on its type:

    * categorical-like columns -> multiselect of values (the first value is
      pre-selected; an empty selection applies no filter)
    * numeric columns -> range slider
    * datetime columns -> date range picker (both end dates inclusive)
    * anything else -> substring / regex text box (an invalid regex is
      matched as a plain substring)

    Filters are applied cumulatively: each widget only offers the values that
    survive the filters chosen before it.

    Args:
        df (pd.DataFrame): Original dataframe (not modified; a copy is filtered)
        base_key: Base of the widget key to be used for keys in this routine
        modification_container: A container within which the filter widgets
            will be created. If None, a new container will be created
        as_categorical: Column names the caller wants treated as categorical.
            None (the default) means no extra columns.
        max_categories: Non-numeric, non-datetime columns with fewer than this
            many distinct non-null values are treated as categorical (default 20)
    Returns:
        pd.DataFrame: Filtered dataframe
    """
    import re  # only needed to validate user-supplied regex patterns

    if as_categorical is None:
        as_categorical = []

    df = _coerce_datetime_columns(df.copy())
    include_nulls = bool(st.checkbox('Include nulls?', value=False, key=base_key + '/inclnulls'))

    if modification_container is None:
        modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect(
            "Filter dataframe on",
            df.dropna(how='all', axis=1).columns,
            key=base_key + '/to_filter',
        )
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")
            # f-string (not `+`) so non-string column labels also produce a valid key;
            # for string labels the key is identical to base_key + column.
            widget_key = f"{base_key}{column}"
            this_column = df[column].dropna()

            if this_column.empty:
                # Earlier filters removed every non-null value of this column, so
                # there is nothing to build a widget from (min/max/first value
                # would all fail). Leave the frame as it is.
                right.caption(f"No non-null values left in {column}")
                continue

            unique_values = list(this_column.unique())

            if _treat_as_categorical(this_column, column, as_categorical, len(unique_values), max_categories):
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    unique_values,
                    default=unique_values[:1],
                    key=widget_key,
                )
                if user_cat_input:
                    df = _apply_mask(df, column, df[column].isin(user_cat_input), include_nulls)

            elif is_numeric_dtype(this_column):
                lo = float(this_column.min())
                hi = float(this_column.max())
                user_num_input = right.slider(
                    f"Values for {column}",
                    lo,
                    hi,
                    (lo, hi),
                    step=(hi - lo) / 100,
                    format='%.4f',
                    key=widget_key,
                )
                if len(user_num_input) == 2:
                    df = _apply_mask(df, column, df[column].between(*user_num_input), include_nulls)

            elif is_datetime64_any_dtype(this_column):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        datetime(this_column.min().year, 1, 1),
                        datetime(this_column.max().year, 12, 31),
                    ),
                    key=widget_key,
                )
                # While the user is mid-selection only one date is returned.
                if len(user_date_input) == 2:
                    start_date, end_date = map(pd.to_datetime, user_date_input)
                    # date_input yields calendar days, which convert to midnight.
                    # Compare against the start of the following day so that
                    # timestamps later on the end date are kept (inclusive end).
                    in_range = (df[column] >= start_date) & (df[column] < end_date + pd.Timedelta(days=1))
                    df = _apply_mask(df, column, in_range, include_nulls)

            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                    key=widget_key,
                )
                if user_text_input:
                    # A malformed pattern would otherwise raise re.error and
                    # crash the app; fall back to a literal substring match.
                    try:
                        re.compile(user_text_input)
                        use_regex = True
                    except re.error:
                        right.caption("Invalid regex; matching as a plain substring instead")
                        use_regex = False
                    # na=False: nulls / non-string cells count as "no match" instead
                    # of producing NaN in the boolean mask (which cannot index a frame).
                    matches = df[column].str.contains(user_text_input, regex=use_regex, na=False)
                    df = _apply_mask(df, column, matches, include_nulls)

    return df


def _coerce_datetime_columns(df):
    """Best-effort parse of object columns into timezone-naive datetimes; mutates and returns *df*."""
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Deliberately best-effort: columns that cannot be parsed as
                # datetimes are left unchanged.
                pass
        if is_datetime64_any_dtype(df[col]):
            # Drop timezone info so every datetime column compares against the
            # naive timestamps produced from the date picker.
            df[col] = df[col].dt.tz_localize(None)
    return df


def _treat_as_categorical(values, column, as_categorical, n_unique, max_categories):
    """Decide whether the non-null *values* of *column* get a multiselect of distinct values."""
    is_numeric = is_numeric_dtype(values)
    is_datetime = is_datetime64_any_dtype(values)
    return (
        isinstance(values.dtype, pd.CategoricalDtype)
        or column in as_categorical
        # Booleans are numeric to pandas, but a 0.0-1.0 float slider is useless for them.
        or pd.api.types.is_bool_dtype(values)
        # A range slider needs min < max, so single-valued numeric columns cannot use one.
        or (is_numeric and n_unique == 1)
        or (not is_numeric and not is_datetime and n_unique < max_categories)
    )


def _apply_mask(df, column, mask, include_nulls):
    """Keep rows of *df* where *mask* is true, plus rows where *column* is null if *include_nulls*."""
    if include_nulls:
        mask = mask | df[column].isnull()
    return df[mask]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment