Skip to content

Instantly share code, notes, and snippets.

@Sinjhin
Created November 24, 2024 17:34
Show Gist options
  • Save Sinjhin/4d80dac7c736c90ff0ce4e78aadb29e8 to your computer and use it in GitHub Desktop.
Save Sinjhin/4d80dac7c736c90ff0ce4e78aadb29e8 to your computer and use it in GitHub Desktop.
Proctor For Multi Dim
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import random
from GADEv2_2.util.dataset_loaders.arc_dataset.loader import ARCDataLoader
from GADEv2_2.networks.multi_dim_trial import MultiDimTrial
class ZeTaskMan:
    """Placeholder task manager; currently has no behaviour."""

    def __init__(self, params):
        # ``params`` is accepted so callers can already construct the class
        # with a config dict, but nothing is stored or used yet.
        del params
class NoFuckingWay:
    """Fit a fresh ``MultiDimTrial`` model to each task's demos, then score its test pairs.

    For every task the model is trained full-batch on the stacked demo
    inputs/outputs (AdamW + ``CosineAnnealingWarmRestarts``). Training stops
    when the rounded predictions match every demo output or when
    ``params['epochs']`` runs out. The test inputs are then predicted and
    compared with the test outputs.

    Args:
        params: Configuration dict. This class reads ``'device'``,
            ``'learning_rate'`` and ``'epochs'``; the whole dict is also
            passed to ``ARCDataLoader``.
    """

    def __init__(self, params):
        self.params = params
        self.device = params['device']
        self.data_loader = ARCDataLoader(params)
        self.train_loader = self.data_loader.get_train_loader()
        self.eval_loader = self.data_loader.get_eval_loader()
        self.test_loader = self.data_loader.get_test_loader()

        # Running statistics, printed by run().
        self.total_demos_solved = 0
        self.total_demos = 0
        self.total_train_tasks_solved = 0
        self.total_train_tasks = 0
        self.total_eval_tasks_solved = 0
        self.total_eval_tasks = 0
        self.errors = 0

        # Load the first training task so a model and optimizer exist right away.
        # NOTE(review): set_task_data() adds this task to total_demos and
        # total_train_tasks; if run() later visits the same task it is counted
        # twice in those denominators — confirm whether that is intended.
        task_data = self.data_loader.get_task(0, mode='train')
        self.set_task_data(task_data)

    def set_task_data(self, task_data):
        """Switch to a new task: stack its demos and build a fresh model/optimizer.

        Args:
            task_data: Task dict with ``'demo_inputs'``, ``'demo_outputs'``,
                ``'test_inputs'`` and ``'test_outputs'`` sequences of tensors.
                All demo inputs (and all demo outputs) must share one shape,
                because they are combined with ``torch.stack``.
        """
        self.task_data = task_data
        self.task_input_shape = self.task_data['demo_inputs'][0].shape
        self.task_output_shape = self.task_data['demo_outputs'][0].shape
        self.task_num_demos = len(self.task_data['demo_inputs'])
        self.total_demos += self.task_num_demos

        # All demos are fed to the model together, stacked along a new dim 0.
        self.multi_dim_input = torch.stack(self.task_data['demo_inputs']).to(self.device)
        self.multi_dim_output = torch.stack(self.task_data['demo_outputs']).float().to(self.device)

        self.test_inputs = self.task_data['test_inputs']
        # Each test pair counts as one "train task" in the statistics.
        self.total_train_tasks += len(self.test_inputs)
        self.test_outputs = self.task_data['test_outputs']

        self.model = MultiDimTrial(
            input_shape=self.task_input_shape,
            output_shape=self.task_output_shape,
            num_demos=self.task_num_demos,
            mode='train',
        ).to(self.device)

        self.learning_rate = self.params['learning_rate']
        self.criterion = nn.SmoothL1Loss()
        # Build the optimizer from the configured learning rate so it agrees
        # with the plateau logic in train(), which rescales self.learning_rate.
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=0.01)
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=500,
            T_mult=2,
            eta_min=1e-6,
        )
        self.max_grad_norm = 1.0

        # Per-task plateau / diagnostics state used by train().
        self.best_accuracy = 0
        self.plateau_counter = 0
        self.plateau_patience = 500
        self.mismatch_in = 0

    def calculate_accuracy(self, predictions, targets, epoch):
        """Return the percentage of elements whose rounded prediction equals the target.

        Returns exactly ``100.0`` when every element matches. On every 100th
        epoch it also prints each mismatching position. It stores the leading
        index of the last mismatch in ``self.mismatch_in``; when called from
        train() that is the demo index.

        Args:
            predictions: Raw (float) model output.
            targets: Tensor compared element-wise with the rounded predictions.
            epoch: Current epoch, used only to throttle the detailed printout.
        """
        predicted_ints = torch.round(predictions).to(self.device)
        targets = targets.to(self.device)
        correct = predicted_ints == targets

        if correct.all():
            print("\nAll elements match exactly!")
            return 100.0

        accuracy = correct.float().mean().item() * 100

        if epoch % 100 == 0:
            print('-' * 100)
            print(f"\n\nEpoch [{epoch}/{self.params['epochs']}] accuracy: {accuracy:.2f}%")
            mismatches = torch.where(~correct)
            print(f"\nFound {len(mismatches[0])} mismatches:")
            # zip the per-dimension index lists back into full coordinates.
            for coords in zip(*(dim_idx.tolist() for dim_idx in mismatches)):
                pos = tuple(coords)
                self.mismatch_in = pos[0]
                pred_val = predictions[pos].item()
                target_val = targets[pos].item()
                print(f"Position {pos}: Prediction (raw)={pred_val:.4f}, (rounded)={round(pred_val)}, Target={target_val}")

        return accuracy

    def train(self):
        """Fit the current model to the task's demo pairs.

        Runs up to ``params['epochs']`` full-batch steps. Training ends on the
        first epoch whose rounded predictions match every demo output, or on
        the last epoch. On that epoch it prints the final metrics and calls
        evaluate_tests().
        """
        num_epochs = self.params['epochs']
        last_epoch = num_epochs - 1
        for epoch in range(num_epochs):
            # One optimisation step over all demos at once.
            self.model.train()
            self.optimizer.zero_grad()
            predictions = self.model(self.multi_dim_input)
            loss = self.criterion(predictions, self.multi_dim_output)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()

            # Re-score in eval mode after the update.
            self.model.eval()
            with torch.no_grad():
                train_predictions = self.model(self.multi_dim_input)
                train_loss = self.criterion(train_predictions, self.multi_dim_output)
                train_accuracy = self.calculate_accuracy(train_predictions, self.multi_dim_output, epoch)

            # CosineAnnealingWarmRestarts.step() takes an optional *epoch*, not a
            # metric: passing the loss would set the cosine position to the loss
            # value instead of advancing it once per epoch.
            self.scheduler.step()

            if train_accuracy > self.best_accuracy:
                self.best_accuracy = train_accuracy
                self.plateau_counter = 0
            else:
                self.plateau_counter += 1

            # No improvement for plateau_patience epochs: randomly double or
            # halve the learning rate.
            # NOTE(review): the cosine scheduler computes the LR from its own
            # base LRs, so this override presumably lasts only until the next
            # scheduler.step() — confirm that is the intended "temporary" kick.
            if self.plateau_counter >= self.plateau_patience:
                if random.random() < 0.5:
                    self.learning_rate *= 2.0
                else:
                    self.learning_rate /= 2.0
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.learning_rate
                print(f"Stuck at plateau, changing learning rate temporarily: {self.optimizer.param_groups[0]['lr']}")
                self.plateau_counter = 0

            # Progress report once accuracy is effectively 100% (>99%) or on the
            # final epoch.
            if train_accuracy > 99 or epoch == last_epoch:
                if train_accuracy > 99 and train_accuracy != 100.0:
                    print("\nEffectively perfect accuracy achieved!")
                print(f"Epoch [{epoch + 1}/{num_epochs}]")
                print(f'Training Loss: {train_loss.item():.4f}')
                print(f'Training Accuracy: {train_accuracy:.2f}%')
                all_match = torch.all(torch.round(train_predictions) == torch.round(self.multi_dim_output))
                if all_match:
                    print("All predictions match targets exactly!")
                else:
                    print("Some predictions still don't match exactly.")

            if train_accuracy == 100.0 or epoch == last_epoch:
                if train_accuracy == 100.0:
                    print("Perfect accuracy achieved!")
                    self.total_demos_solved += self.task_num_demos
                print(f"\nTraining completed at epoch {epoch + 1}")
                print(f"Final Training Loss: {train_loss.item():.4f}")
                print(f"Final Training Accuracy: {train_accuracy:.2f}%")
                self.evaluate_tests()
                break
            else:
                # NOTE(review): this prints a sample on every epoch that is not
                # final; self.mismatch_in is only refreshed on every 100th epoch
                # by calculate_accuracy(). Confirm this verbosity is intended.
                sample_idx = self.mismatch_in
                print("\nSample Prediction vs Target:")
                print(f"Predictions:\n{torch.round(train_predictions[sample_idx]).long()}")
                print(f"Targets:\n{self.multi_dim_output[sample_idx].long()}")
                print('-' * 80)

    def evaluate_tests(self, set = 'train'):
        """Evaluate the model on test input with two attempts

        A test pair is solved when either rounded prediction equals its target
        exactly. Solved pairs are counted in ``total_eval_tasks_solved`` when
        ``set == 'eval'`` and in ``total_train_tasks_solved`` otherwise. The
        model is left in eval mode with ``set_mode('test')`` applied.

        Args:
            set: ``'eval'`` for the eval counters; anything else counts as
                train. (Shadows the builtin ``set``; name kept for callers.)
        """
        self.model.eval()
        self.model.set_mode('test')
        with torch.no_grad():
            for idx, raw_input in enumerate(self.test_inputs):
                test_input = raw_input.to(self.device)
                test_output = self.test_outputs[idx].to(self.device)

                # NOTE(review): both attempts run the same model on the same
                # input without gradients; they can only differ if the model is
                # stochastic in 'test' mode.
                test_prediction1 = self.model(test_input)
                prediction1_correct = torch.all(torch.round(test_prediction1) == test_output)
                test_prediction2 = self.model(test_input)
                prediction2_correct = torch.all(torch.round(test_prediction2) == test_output)

                print("\nTest Results:")
                print("First attempt prediction:")
                print(torch.round(test_prediction1).cpu().long())
                print("Second attempt prediction:")
                print(torch.round(test_prediction2).cpu().long())
                print("Target:")
                print(self.test_outputs[idx].cpu().long())

                if prediction1_correct or prediction2_correct:
                    print("Test passed! At least one prediction matched exactly.")
                    if set == 'eval':
                        self.total_eval_tasks_solved += 1
                    else:
                        self.total_train_tasks_solved += 1
                else:
                    print("Test failed. Neither prediction matched exactly.")

    def run(self):
        """Train and test on every task from the training loader, printing stats after each.

        A failure on one task is reported, counted in ``self.errors`` and
        skipped, so the remaining tasks still run.
        """
        for task in self.train_loader:
            try:
                self.set_task_data(task)
                self.train()
            except Exception as e:
                # Top-level per-task boundary: report and move on.
                print(f"Error occurred during training: {e}")
                print("On task:", task['task_id'])
                self.errors += 1
                print("Continuing to next task...")
            print(f"\n\nStats so far:")
            print(f"Total demos solved: {self.total_demos_solved}/{self.total_demos}")
            print(f"Total train tasks solved: {self.total_train_tasks_solved}/{self.total_train_tasks}")
            print(f"Total eval tasks solved: {self.total_eval_tasks_solved}/{self.total_eval_tasks}")
            print(f"Errors: {self.errors}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment