Proctor For Multi Dim
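A per-task training harness for ARC: it fits a fresh MultiDimTrial model to each task's demo pairs, then scores the task's test pairs with two attempts, keeping running solved/error counts across tasks.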
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from GADEv2_2.util.dataset_loaders.arc_dataset.loader import ARCDataLoader
from GADEv2_2.networks.multi_dim_trial import MultiDimTrial

class ZeTaskMan:
    def __init__(self, params):
        pass


class NoFuckingWay:
    def __init__(self, params):
        self.params = params
        self.device = params['device']
        self.data_loader = ARCDataLoader(params)
        self.train_loader = self.data_loader.get_train_loader()
        self.eval_loader = self.data_loader.get_eval_loader()
        self.test_loader = self.data_loader.get_test_loader()
        self.total_demos_solved = 0
        self.total_demos = 0
        self.total_train_tasks_solved = 0
        self.total_train_tasks = 0
        self.total_eval_tasks_solved = 0
        self.total_eval_tasks = 0
        self.errors = 0
        task_data = self.data_loader.get_task(0, mode='train')
        self.set_task_data(task_data)
    def set_task_data(self, task_data):
        self.task_data = task_data
        self.task_input_shape = self.task_data['demo_inputs'][0].shape
        self.task_output_shape = self.task_data['demo_outputs'][0].shape
        self.task_num_demos = len(self.task_data['demo_inputs'])
        self.total_demos += self.task_num_demos
        self.multi_dim_input = torch.stack(self.task_data['demo_inputs']).to(self.device)
        self.multi_dim_output = torch.stack(self.task_data['demo_outputs']).float().to(self.device)
        self.test_inputs = self.task_data['test_inputs']
        self.total_train_tasks += len(self.test_inputs)
        self.test_outputs = self.task_data['test_outputs']
        # A fresh model, optimizer, and scheduler are built for every task,
        # so no weights are shared between tasks.
        self.model = MultiDimTrial(
            input_shape=self.task_input_shape,
            output_shape=self.task_output_shape,
            num_demos=self.task_num_demos,
            mode='train'
        ).to(self.device)
        self.learning_rate = self.params['learning_rate']
        self.criterion = nn.SmoothL1Loss()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=0.01)
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=500,
            T_mult=2,
            eta_min=1e-6
        )
        self.max_grad_norm = 1.0
        self.best_accuracy = 0
        self.plateau_counter = 0
        self.plateau_patience = 500
        self.mismatch_in = 0
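    # Element-wise accuracy: round each predicted cell to the nearest integer
    # and compare against the integer targets; 100.0 means every cell of every
    # demo grid matches.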
    def calculate_accuracy(self, predictions, targets, epoch):
        # Round predictions to the nearest integer
        predicted_ints = torch.round(predictions).to(self.device)
        # Compare with targets
        targets = targets.to(self.device)
        correct = (predicted_ints == targets)
        if correct.all():
            print("\nAll elements match exactly!")
            return 100.0
        # Calculate accuracy
        accuracy = correct.float().mean().item() * 100
        # Print a detailed comparison every 100 epochs
        if epoch % 100 == 0:
            print('-' * 100)
            print(f"\n\nEpoch [{epoch}/{self.params['epochs']}] accuracy: {accuracy:.2f}%")
            # Find positions where prediction and target disagree
            mismatches = torch.where(~correct)
            print(f"\nFound {len(mismatches[0])} mismatches:")
            for i in range(len(mismatches[0])):
                pos = tuple(m[i].item() for m in mismatches)
                self.mismatch_in = pos[0]
                pred_val = predictions[pos].item()
                target_val = targets[pos].item()
                print(f"Position {pos}: Prediction (raw)={pred_val:.4f}, (rounded)={round(pred_val)}, Target={target_val}")
        return accuracy
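    # Per-task training loop: full-batch gradient steps on all demo pairs at
    # once, with early stopping as soon as the rounded predictions match the
    # demo outputs exactly.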
    def train(self):
        for epoch in range(self.params['epochs']):
            self.model.train()
            self.optimizer.zero_grad()
            # Forward pass
            predictions = self.model(self.multi_dim_input)
            loss = self.criterion(predictions, self.multi_dim_output)
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()

            self.model.eval()
            with torch.no_grad():
                train_predictions = self.model(self.multi_dim_input)
                train_loss = self.criterion(train_predictions, self.multi_dim_output)
                train_accuracy = self.calculate_accuracy(train_predictions, self.multi_dim_output, epoch)

            # CosineAnnealingWarmRestarts steps per epoch; it does not take a loss
            self.scheduler.step()

            # Check for improvement
            if train_accuracy > self.best_accuracy:
                self.best_accuracy = train_accuracy
                self.plateau_counter = 0
            else:
                self.plateau_counter += 1

            # If we're stuck on a plateau, randomly double or halve the learning rate
            if self.plateau_counter >= self.plateau_patience:
                if random.random() < 0.5:
                    self.learning_rate *= 2.0
                else:
                    self.learning_rate /= 2.0
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.learning_rate
                print(f"Stuck at plateau, changing learning rate temporarily: {self.optimizer.param_groups[0]['lr']}")
                self.plateau_counter = 0

            # Check if accuracy is effectively 100% (allowing for floating-point error)
            if train_accuracy > 99 or epoch == self.params['epochs'] - 1:
                if train_accuracy > 99 and train_accuracy != 100.0:
                    print("\nEffectively perfect accuracy achieved!")
                print(f"Epoch [{epoch + 1}/{self.params['epochs']}]")
                print(f'Training Loss: {train_loss.item():.4f}')
                print(f'Training Accuracy: {train_accuracy:.2f}%')
                # Verify all predictions match
                all_match = torch.all(torch.round(train_predictions) == torch.round(self.multi_dim_output))
                if all_match:
                    print("All predictions match targets exactly!")
                else:
                    print("Some predictions still don't match exactly.")
                if train_accuracy == 100.0 or epoch == self.params['epochs'] - 1:
                    if train_accuracy == 100.0:
                        print("Perfect accuracy achieved!")
                        self.total_demos_solved += self.task_num_demos
                    print(f"\nTraining completed at epoch {epoch + 1}")
                    print(f"Final Training Loss: {train_loss.item():.4f}")
                    print(f"Final Training Accuracy: {train_accuracy:.2f}%")
                    # Test the model
                    self.evaluate_tests()
                    break
                else:
                    # Not perfect yet: show the last mismatching demo recorded above
                    sample_idx = self.mismatch_in
                    print("\nSample Prediction vs Target:")
                    print(f"Predictions:\n{torch.round(train_predictions[sample_idx]).long()}")
                    print(f"Targets:\n{self.multi_dim_output[sample_idx].long()}")
                    print('-' * 80)
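    # ARC-style scoring allows two attempts per test input. Note that with the
    # model in eval mode the two forward passes below are identical unless
    # MultiDimTrial is stochastic in 'test' mode.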
    def evaluate_tests(self, split='train'):
        """Evaluate the model on the test inputs, allowing two attempts each."""
        self.model.eval()
        self.model.set_mode('test')
        for idx in range(len(self.test_inputs)):
            with torch.no_grad():
                test_input = self.test_inputs[idx].to(self.device)
                test_output = self.test_outputs[idx].to(self.device)
                # First attempt
                test_prediction1 = self.model(test_input)
                prediction1_correct = torch.all(torch.round(test_prediction1) == test_output)
                # Second attempt
                test_prediction2 = self.model(test_input)
                prediction2_correct = torch.all(torch.round(test_prediction2) == test_output)
                print("\nTest Results:")
                print("First attempt prediction:")
                print(torch.round(test_prediction1).cpu().long())
                print("Second attempt prediction:")
                print(torch.round(test_prediction2).cpu().long())
                print("Target:")
                print(self.test_outputs[idx].cpu().long())
                if prediction1_correct or prediction2_correct:
                    print("Test passed! At least one prediction matched exactly.")
                    if split == 'eval':
                        self.total_eval_tasks_solved += 1
                    else:
                        self.total_train_tasks_solved += 1
                else:
                    print("Test failed. Neither prediction matched exactly.")
    def run(self):
        for task in self.train_loader:
            try:
                self.set_task_data(task)
                self.train()
            except Exception as e:
                print(f"Error occurred during training: {e}")
                print("On task:", task['task_id'])
                self.errors += 1
                print("Continuing to next task...")
            print("\n\nStats so far:")
            print(f"Total demos solved: {self.total_demos_solved}/{self.total_demos}")
            print(f"Total train tasks solved: {self.total_train_tasks_solved}/{self.total_train_tasks}")
            print(f"Total eval tasks solved: {self.total_eval_tasks_solved}/{self.total_eval_tasks}")
            print(f"Errors: {self.errors}")