Created
August 15, 2022 00:36
-
-
Save chelseaparlett/bd734748158d24e578f4daa9823f8c47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from shiny import App, render, ui, reactive | |
from pathlib import Path | |
# Import modules for plot rendering | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import pandas as pd | |
# Import modeling packages | |
from sklearn.naive_bayes import GaussianNB | |
from sklearn.tree import DecisionTreeClassifier | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.neighbors import KNeighborsClassifier | |
from sklearn.preprocessing import StandardScaler | |
def plotDecisionBoundary(Xdat, ydat, mod, model_type):
    """Plot the 2-D decision boundary of a fitted classifier.

    Parameters
    ----------
    Xdat : pd.DataFrame with exactly two feature columns (assumed already
        standardized so the [-3, 3] plotting window covers the data —
        see modeldata, which scales in place).
    ydat : pd.Series of class labels aligned with Xdat.
    mod : fitted sklearn-style classifier exposing ``.predict``.
    model_type : str, used only in the plot title.

    Returns
    -------
    matplotlib Figure with shaded predicted regions and the data points.
    """
    Xy = pd.concat([Xdat, ydat], axis=1)

    # Evenly spaced axes over the standardized feature range.
    x0_range = np.linspace(-3, 3, num=100)
    x1_range = np.linspace(-3, 3, num=100)

    # Cartesian product of the two axes: each x0 paired with every x1.
    # BUGFIX: the repeat/tile counts must equal the *opposite* axis length
    # (100), not a hard-coded 1000 — the old counts built a 100,000-point
    # grid that covered the same 100x100 lattice ten times over.
    x0 = np.repeat(x0_range, len(x1_range))
    x1 = np.tile(x1_range, len(x0_range))
    x_grid = pd.DataFrame({Xdat.columns[0]: x0, Xdat.columns[1]: x1})

    # Classify every background point; "p" drives the region shading.
    x_grid["p"] = mod.predict(x_grid)

    fig, ax = plt.subplots()
    ys = np.unique(ydat)
    colors = ["#264653", "#2a9d8f", "#e9c46a"]  # supports up to 3 classes
    palette = {ys[i]: colors[i] for i in range(len(ys))}
    tt = model_type + " Decision Boundary"

    # Faint scatter = predicted regions; opaque scatter = actual data.
    sns.scatterplot(x=Xdat.columns[0], y=Xdat.columns[1],
                    hue="p", ax=ax, data=x_grid, alpha=0.1,
                    s=5, legend=False, palette=palette)
    bound = sns.scatterplot(x=Xdat.columns[0], y=Xdat.columns[1],
                            hue=ydat, ax=ax, data=Xy, palette=palette)
    bound.set(xlim=[-3, 3], ylim=[-3, 3], title=tt)
    bound.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    return fig
def modeldata(Xdat, ydat, modeltype, depth=10, k=5):
    """Standardize the features and fit the requested classifier.

    Parameters
    ----------
    Xdat : pd.DataFrame of features. NOTE: standardized IN PLACE — the
        callers pass the shared reactive DataFrame and then plot it, so
        the plotted points must live in the same scaled space the model
        was fit in. Do not "fix" this without updating the plot code.
    ydat : pd.Series of labels.
    modeltype : one of "Logistic Regression", "Naive Bayes",
        "Decision Tree", "KNN".
    depth : int, max_depth for the decision tree (ignored otherwise).
    k : int, n_neighbors for KNN (ignored otherwise).

    Returns
    -------
    The fitted sklearn model.

    Raises
    ------
    ValueError
        If modeltype is not one of the recognized names. (Previously an
        unknown name fell through to an UnboundLocalError at mod.fit.)
    """
    if modeltype == "Logistic Regression":
        mod = LogisticRegression()
    elif modeltype == "Naive Bayes":
        mod = GaussianNB()
    elif modeltype == "Decision Tree":
        mod = DecisionTreeClassifier(max_depth=depth)
    elif modeltype == "KNN":
        mod = KNeighborsClassifier(n_neighbors=k)
    else:
        raise ValueError(f"Unknown model type: {modeltype!r}")

    # In-place standardization (see docstring NOTE above).
    z = StandardScaler()
    Xdat[Xdat.columns] = z.fit_transform(Xdat)
    mod.fit(Xdat, ydat)
    return mod
# Page layout: a controls row (dataset selector + model hyperparameter
# sliders), a trigger button, and a 2x2 grid of decision-boundary plots.
_controls_row = ui.row(
    ui.column(
        4,
        ui.input_select(
            "dataset", "Choose a Data Set:",
            ["Palmer Penguins", "Iris", "Diabetes"],
        ),
    ),
    ui.column(
        4,
        ui.input_slider("depth", "Max Depth", min=1, max=100, value=10),
    ),
    ui.column(
        4,
        ui.input_slider(
            "nneighbors", "Number of Neighbors", min=1, max=100, value=10,
        ),
    ),
)

_plots_row = ui.row(
    ui.column(6, ui.output_plot("plot_lr"), ui.output_plot("plot_nb")),
    ui.column(6, ui.output_plot("plot_dt"), ui.output_plot("plot_knn")),
)

app_ui = ui.page_fluid(
    _controls_row,
    ui.input_action_button("go", "Create Decision Boundary"),
    _plots_row,
)
def server(input, output, session):
    """Shiny server: load the selected dataset on button press, then
    render one decision-boundary plot per model type."""
    # Reactive holders for the current feature matrix and label vector.
    X = reactive.Value(pd.DataFrame())
    y = reactive.Value(pd.Series())

    # Dataset name -> (csv filename, feature columns, label column).
    # Collapses what was three copy-pasted if/elif branches into data.
    datasets = {
        "Palmer Penguins": ("penguins.csv",
                            ["bill_length_mm", "bill_depth_mm"], "species"),
        "Diabetes": ("diabetes.csv",
                     ["Glucose", "BloodPressure"], "Outcome"),
        "Iris": ("iris.csv",
                 ["sepal_length", "sepal_width"], "species"),
    }

    @reactive.Effect
    @reactive.event(input.go)
    def _():
        # Load the chosen CSV (expected next to this file) into X / y.
        choice = input.dataset()
        if choice not in datasets:
            print("ISSUE")  # preserved fallback for an unexpected choice
            return
        fname, feats, label = datasets[choice]
        infile = Path(__file__).parent / fname
        # Select only the needed columns, then drop incomplete rows.
        df = pd.read_csv(infile)[feats + [label]].dropna()
        X.set(df[feats])
        y.set(df[label])

    @output
    @render.plot()
    def plot_lr():
        mod = modeldata(X.get(), y.get(), "Logistic Regression")
        return plotDecisionBoundary(X.get(), y.get(), mod,
                                    "Logistic Regression")

    @output
    @render.plot()
    def plot_nb():
        mod = modeldata(X.get(), y.get(), "Naive Bayes")
        return plotDecisionBoundary(X.get(), y.get(), mod, "Naive Bayes")

    @output
    @render.plot()
    def plot_dt():
        # Tree depth follows the "Max Depth" slider.
        mod = modeldata(X.get(), y.get(), "Decision Tree",
                        depth=input.depth())
        return plotDecisionBoundary(X.get(), y.get(), mod, "Decision Tree")

    @output
    @render.plot()
    def plot_knn():
        # Neighbor count follows the "Number of Neighbors" slider.
        mod = modeldata(X.get(), y.get(), "KNN", k=input.nneighbors())
        return plotDecisionBoundary(X.get(), y.get(), mod, "KNN")
app = App(app_ui, server, debug=True) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment