Created November 28, 2024 13:24
-
-
Save tail-call/f566505c2fa89ed0ee1346d487889704 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import curses | |
from curses import wrapper | |
from typing import Any | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from cgtnnlib.datasets import breast_cancer_dataset, car_evaluation_dataset, student_performance_factors_dataset | |
from cgtnnlib.AugmentedReLUNetwork import AugmentedReLUNetwork | |
from cgtnnlib.RegularNetwork import RegularNetwork | |
from cgtnnlib.training import train_model | |
from cgtnnlib.common import evaluate_classification_model, evaluate_regression_model | |
# App's state: module-level containers shared by the menu functions.
models: list[Any] = list()  # trained models, appended in creation order
evaluations: list[Any] = list()  # one entry appended per evaluation run
def main(stdscr):
    """Run the top-level menu loop until the user selects "Exit".

    Up/Down move the highlighted entry; Enter (KEY_ENTER, LF or CR)
    activates it.

    Args:
        stdscr: the root curses window, as passed in by ``curses.wrapper``.
    """
    try:
        curses.curs_set(0)  # Hide the cursor
    except curses.error:
        # Some terminals cannot change cursor visibility; purely cosmetic.
        pass

    # Colour pair 1 has to be initialised with init_pair() before
    # color_pair(1) is meaningful; the original never did this, so the
    # highlight used undefined colours. Fall back to reverse video when the
    # terminal has no colour support (or colours were not started).
    highlight = curses.A_REVERSE
    if curses.has_colors():
        try:
            curses.init_pair(1, curses.COLOR_BLACK, curses.COLOR_WHITE)
            highlight = curses.color_pair(1)
        except curses.error:
            pass

    stdscr.clear()
    stdscr.refresh()

    menu = ['Train Model', 'List Models', 'Run Evaluation', 'Exit']
    # Index-aligned with ``menu``; ``None`` marks the entry that exits.
    actions = [train_model_menu, list_models_menu, run_evaluation_menu, None]
    current_row = 0

    while True:
        # erase() avoids the full-screen repaint (flicker) that clear()
        # forces on every iteration.
        stdscr.erase()
        h, w = stdscr.getmaxyx()
        for idx, row in enumerate(menu):
            x = max(0, w // 2 - len(row) // 2)
            y = h // 2 - len(menu) // 2 + idx
            if not 0 <= y < h:
                continue  # terminal too short to show this entry
            attr = highlight if idx == current_row else curses.A_NORMAL
            try:
                stdscr.addstr(y, x, row[: max(0, w - x)], attr)
            except curses.error:
                # addstr raises after writing into the window's lower-right
                # cell; the text is still drawn, so ignore it.
                pass
        stdscr.refresh()

        key = stdscr.getch()
        if key == curses.KEY_UP and current_row > 0:
            current_row -= 1
        elif key == curses.KEY_DOWN and current_row < len(menu) - 1:
            current_row += 1
        elif key == curses.KEY_ENTER or key in (10, 13):
            action = actions[current_row]
            if action is None:
                break
            action(stdscr)
def train_model_menu(stdscr):
    """Prompt for a dataset, train a new model on it and append it to ``models``.

    The dataset is chosen with a single digit key ('1'-'3'). Any other key
    (letters, '0', out-of-range digits, special keys with codes above 255, or
    -1 from a getch() without input) cancels and returns to the caller. The
    original ``int(chr(key))`` raised, or left X/y unbound, for these keys.

    Args:
        stdscr: the curses window to draw the prompt on.
    """
    stdscr.clear()
    stdscr.addstr(0, 0, "Select Dataset:")
    datasets = ['Breast Cancer', 'Car Evaluation', 'Student Performance Factors']
    for idx, dataset in enumerate(datasets):
        stdscr.addstr(idx + 1, 0, f"{idx + 1}. {dataset}")
    stdscr.refresh()

    key = stdscr.getch()
    # Map '1' -> 0, '2' -> 1, ...; anything else lands outside the range.
    selected_dataset = key - ord('1')
    if not 0 <= selected_dataset < len(datasets):
        return

    # Training can take a while; show feedback instead of a frozen prompt.
    stdscr.addstr(len(datasets) + 2, 0, f"Training on {datasets[selected_dataset]}...")
    stdscr.refresh()

    # Index-aligned with ``datasets``; built only after validation.
    loaders = (
        breast_cancer_dataset,
        car_evaluation_dataset,
        student_performance_factors_dataset,
    )
    X, y = loaders[selected_dataset]()

    # NOTE(review): outputs_count=1 is used for every dataset; confirm it
    # matches each loader's target encoding (binary / multi-class / regression).
    model = AugmentedReLUNetwork(inputs_count=X.shape[1], outputs_count=1, p=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_model(model, X, y, optimizer, epochs=10)
    models.append(model)
def list_models_menu(stdscr):
    """Show the trained models, one per row, then wait for any key.

    Output is clipped to the window. If there are more models than rows, the
    last visible row summarises how many were hidden. An empty list shows a
    placeholder line instead of a bare header.

    Args:
        stdscr: the curses window to draw on.
    """
    stdscr.clear()
    h, w = stdscr.getmaxyx()
    # Never write into the last column: addstr raises when it fills the
    # lower-right cell of the window.
    width = max(0, w - 1)
    stdscr.addstr(0, 0, "List of Models:"[:width])

    if not models:
        entries = ["(no models trained yet)"]
    else:
        # NOTE(review): models are presumably torch nn.Modules, whose str()
        # lists submodules over several lines; keep only the first line so
        # entries do not overwrite each other.
        entries = [
            f"{idx + 1}. {(str(model).splitlines() or [''])[0]}"
            for idx, model in enumerate(models)
        ]

    available = h - 1  # rows below the header
    if len(entries) > available > 0:
        hidden = len(entries) - (available - 1)
        entries = entries[: available - 1] + [f"... and {hidden} more"]
    for row, text in enumerate(entries[: max(0, available)], start=1):
        stdscr.addstr(row, 0, text[:width])

    stdscr.refresh()
    stdscr.getch()
def run_evaluation_menu(stdscr):
    """Let the user pick a trained model by digit, evaluate it and show results.

    Models are selected with keys '1'-'9'. Any other key returns without
    evaluating. The original converted '0' to index -1, which passed its
    ``< len(models)`` check and silently evaluated the *last* model. Its
    ``int(chr(key))`` also raised on non-digit keys. Each evaluation result is
    appended to the module-level ``evaluations`` list.

    Args:
        stdscr: the curses window to draw on.
    """
    stdscr.clear()
    h, w = stdscr.getmaxyx()
    # Stay out of the last column: addstr raises in the lower-right cell.
    width = max(0, w - 1)
    stdscr.addstr(0, 0, "Select Model to Evaluate:"[:width])

    if not models:
        stdscr.addstr(1, 0, "(no models trained yet)"[:width])
        stdscr.refresh()
        stdscr.getch()
        return

    # Only single-digit keys are read, so at most nine models are selectable.
    selectable = min(len(models), 9, max(0, h - 1))
    for idx in range(selectable):
        # Show only the first line of the (possibly multi-line) str(model).
        summary = (str(models[idx]).splitlines() or [''])[0]
        stdscr.addstr(idx + 1, 0, f"{idx + 1}. {summary}"[:width])
    stdscr.refresh()

    key = stdscr.getch()
    selected_model = key - ord('1')  # '1' -> 0; other keys fall outside range
    if not 0 <= selected_model < selectable:
        return

    model = models[selected_model]
    # NOTE(review): every model is evaluated on the breast-cancer data with
    # the classification evaluator, regardless of which dataset it was trained
    # on. Models trained elsewhere may expect a different input width.
    X, y = breast_cancer_dataset()  # Example dataset
    evaluation = evaluate_classification_model(model, X, y)
    evaluations.append(evaluation)

    stdscr.clear()
    stdscr.addstr(0, 0, "Evaluation Results:"[:width])
    # Iterates whatever evaluate_classification_model returns (its type is
    # not visible here); rows beyond the window height are dropped.
    for idx, result in enumerate(evaluation):
        if idx + 1 >= h:
            break
        stdscr.addstr(idx + 1, 0, f"{result}"[:width])
    stdscr.refresh()
    stdscr.getch()
if __name__ == "__main__":
    # curses.wrapper initialises the terminal, passes the root window to
    # main, and restores the terminal state on exit, even if main raises.
    curses.wrapper(main)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.