# GitHub Gist f566505c2fa89ed0ee1346d487889704 by @tail-call,
# created November 28, 2024 13:24.
import curses
from curses import wrapper
from typing import Any
import torch
import torch.nn as nn
import torch.optim as optim
from cgtnnlib.datasets import breast_cancer_dataset, car_evaluation_dataset, student_performance_factors_dataset
from cgtnnlib.AugmentedReLUNetwork import AugmentedReLUNetwork
from cgtnnlib.RegularNetwork import RegularNetwork
from cgtnnlib.training import train_model
from cgtnnlib.common import evaluate_classification_model, evaluate_regression_model
# App's state: shared by the menu handlers and kept in memory only, so
# trained models and evaluation results do not survive a restart.
models: list[Any] = []  # trained models, in the order they were created
evaluations: list[Any] = []  # evaluation results, in the order they were run
def main(stdscr):
    """Run the top-level menu loop until the user chooses "Exit".

    Up/Down move the highlight; Enter (KEY_ENTER, LF or CR) runs the
    highlighted entry's handler, which takes over the screen and returns
    here when done.

    Args:
        stdscr: the curses window supplied by ``curses.wrapper``.
    """
    try:
        curses.curs_set(0)  # Hide the cursor
    except curses.error:
        pass  # some terminals cannot hide the cursor; purely cosmetic
    # Reverse video works on every terminal. color_pair(1) was never set up
    # with init_pair, so its colours are unspecified and the selected row
    # could be invisible.
    highlight = curses.A_REVERSE
    # (label, handler) pairs; a handler of None means "leave the loop".
    menu = [
        ('Train Model', train_model_menu),
        ('List Models', list_models_menu),
        ('Run Evaluation', run_evaluation_menu),
        ('Exit', None),
    ]
    current_row = 0
    while True:
        stdscr.erase()  # erase() avoids clear()'s full-repaint flicker
        h, w = stdscr.getmaxyx()
        top = h // 2 - len(menu) // 2
        for idx, (label, _handler) in enumerate(menu):
            x = max(0, w // 2 - len(label) // 2)
            attr = highlight if idx == current_row else curses.A_NORMAL
            try:
                stdscr.addstr(top + idx, x, label, attr)
            except curses.error:
                pass  # terminal too small to show this row; skip it
        stdscr.refresh()
        key = stdscr.getch()
        if key == curses.KEY_UP and current_row > 0:
            current_row -= 1
        elif key == curses.KEY_DOWN and current_row < len(menu) - 1:
            current_row += 1
        elif key in (curses.KEY_ENTER, 10, 13):
            handler = menu[current_row][1]
            if handler is None:
                break
            handler(stdscr)
def train_model_menu(stdscr):
    """Ask for a dataset, train a new model on it and append it to ``models``.

    Shows a numbered dataset list and reads one key. Digits 1..N select a
    dataset; any other key returns to the caller without training.

    Args:
        stdscr: the curses window to draw on.
    """
    dataset_labels = ('Breast Cancer', 'Car Evaluation', 'Student Performance Factors')
    stdscr.clear()
    stdscr.addstr(0, 0, "Select Dataset:")
    for idx, dataset in enumerate(dataset_labels):
        stdscr.addstr(idx + 1, 0, f"{idx + 1}. {dataset}")
    stdscr.refresh()
    key = stdscr.getch()
    # Map the key code straight to an index instead of int(chr(key)), which
    # raises ValueError on arrows/letters. The range check also rejects '0'
    # and digits past the list, which previously left X, y unassigned.
    selected_dataset = key - ord('1')
    if not 0 <= selected_dataset < len(dataset_labels):
        return
    # Same order as dataset_labels; looked up only once the choice is valid.
    loaders = (
        breast_cancer_dataset,
        car_evaluation_dataset,
        student_performance_factors_dataset,
    )
    try:
        # Loading and training run synchronously, so say what is happening.
        stdscr.addstr(
            len(dataset_labels) + 2, 0,
            f"Training on {dataset_labels[selected_dataset]}...",
        )
        stdscr.refresh()
    except curses.error:
        pass  # terminal too small for the status line; informational only
    X, y = loaders[selected_dataset]()
    # NOTE(review): outputs_count=1 is used for every dataset; confirm it fits
    # each dataset's target (they may not all be the same kind of task).
    model = AugmentedReLUNetwork(inputs_count=X.shape[1], outputs_count=1, p=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_model(model, X, y, optimizer, epochs=10)
    models.append(model)
def list_models_menu(stdscr):
    """List the trained models, one per row, and wait for any key.

    Only the first line of ``str(model)`` is shown, because a torch
    ``nn.Module`` renders as a multi-line architecture dump that would spill
    over the following entries. Rows are clipped to the window so a long list
    or a narrow terminal does not raise ``curses.error``. Models beyond the
    window height are not shown.

    Args:
        stdscr: the curses window to draw on.
    """
    stdscr.clear()
    stdscr.addstr(0, 0, "List of Models:")
    h, w = stdscr.getmaxyx()
    max_width = max(0, w - 1)
    if not models and h > 1:
        stdscr.addstr(1, 0, "(no models trained yet)"[:max_width])
    for idx, model in enumerate(models[: max(0, h - 1)]):
        summary = str(model).partition("\n")[0]
        stdscr.addstr(idx + 1, 0, f"{idx + 1}. {summary}"[:max_width])
    stdscr.refresh()
    stdscr.getch()
def _evaluation_lines(evaluation):
    """Turn an evaluation result into a list of display strings.

    The return type of ``evaluate_classification_model`` is not visible here
    (TODO confirm). Dicts become ``key: value`` rows, strings and
    non-iterables a single row, and other iterables one row per item.
    """
    if isinstance(evaluation, dict):
        return [f"{name}: {value}" for name, value in evaluation.items()]
    if isinstance(evaluation, str):
        return [evaluation]
    try:
        items = list(evaluation)
    except TypeError:
        return [str(evaluation)]
    return [f"{item}" for item in items]


def run_evaluation_menu(stdscr):
    """Let the user pick a trained model, evaluate it and show the result.

    Models are chosen with digit keys 1..9, so only the first nine are
    selectable. Any other key returns without evaluating. A successful
    result is appended to ``evaluations``.

    Args:
        stdscr: the curses window to draw on.
    """
    stdscr.clear()
    stdscr.addstr(0, 0, "Select Model to Evaluate:")
    h, w = stdscr.getmaxyx()
    max_width = max(0, w - 1)
    for idx, model in enumerate(models[: max(0, h - 1)]):
        # First line only: a torch module's str() spans many lines.
        summary = str(model).partition("\n")[0]
        stdscr.addstr(idx + 1, 0, f"{idx + 1}. {summary}"[:max_width])
    stdscr.refresh()
    key = stdscr.getch()
    # int(chr(key)) crashed on non-digit keys, and '0' mapped to -1, which
    # silently selected the last model; accept only the listed digits.
    selected_model = key - ord('1')
    if not 0 <= selected_model < min(len(models), 9):
        return
    model = models[selected_model]
    # NOTE(review): evaluation always uses the Breast Cancer dataset, whatever
    # the model was trained on, so models from other datasets may fail here.
    X, y = breast_cancer_dataset()  # Example dataset
    try:
        evaluation = evaluate_classification_model(model, X, y)
    except Exception as err:  # UI boundary: report instead of exiting and losing all models
        lines = [f"Evaluation failed: {type(err).__name__}: {err}"]
    else:
        evaluations.append(evaluation)
        lines = _evaluation_lines(evaluation)
    stdscr.clear()
    stdscr.addstr(0, 0, "Evaluation Results:")
    for idx, line in enumerate(lines[: max(0, h - 1)]):
        # Keep each result on its own row and inside the window.
        stdscr.addstr(idx + 1, 0, line.replace("\n", " ")[:max_width])
    stdscr.refresh()
    stdscr.getch()
if __name__ == "__main__":
    # curses.wrapper initialises the terminal, calls main(stdscr) and restores
    # the terminal afterwards, even if main raises.
    curses.wrapper(main)
# End of gist.