Skip to content

Instantly share code, notes, and snippets.

@Sinjhin
Created November 16, 2024 23:57
Show Gist options
  • Save Sinjhin/4c6699cc529ed3a9286ea2c2e01946f4 to your computer and use it in GitHub Desktop.
ARC-AGI Loader
import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class ARCDataset(Dataset):
    """PyTorch dataset over ARC-AGI tasks loaded from the Kaggle JSON files.

    Each item is a dict with:
        task_id      -- the task identifier key from the challenges JSON (str)
        demo_inputs  -- list of long tensors, one per training-example input
        demo_outputs -- list of long tensors, one per training-example output
        test_inputs  -- list of long tensors, one per test-example input
        test_outputs -- list of long tensors; empty (0-element) tensors when
                        no solution is available for that test (e.g. the
                        hidden test split, which ships without solutions)
    """

    def __init__(self, challenges_file, solutions_file=None):
        """Load every task from *challenges_file*, pairing each test input
        with its output grid from *solutions_file* when one is given.

        challenges_file -- path to a JSON file mapping task_id -> task data
                           (with 'train' and 'test' example lists)
        solutions_file  -- optional path to a JSON file mapping
                           task_id -> list of output grids, one per test
        """
        self.tasks = []

        with open(challenges_file, 'r') as f:
            challenges = json.load(f)

        # The solutions JSON is a dict keyed by task ID; each value is the
        # list of test output grids for that task.
        if solutions_file:
            with open(solutions_file, 'r') as f:
                solutions_dict = json.load(f)
        else:
            solutions_dict = {}

        for task_id, task_data in challenges.items():
            task = {'task_id': task_id}

            # Training (demonstration) input/output pairs.
            demo_inputs = []
            demo_outputs = []
            for example in task_data.get('train', []):
                demo_inputs.append(torch.tensor(example['input'], dtype=torch.long))
                demo_outputs.append(torch.tensor(example['output'], dtype=torch.long))
            task['demo_inputs'] = demo_inputs
            task['demo_outputs'] = demo_outputs

            # Test pairs. Named `task_solutions` (not `solutions`) so it does
            # not shadow the loaded solutions JSON as the original code did.
            task_solutions = solutions_dict.get(task_id)
            test_inputs = []
            test_outputs = []
            for i, test in enumerate(task_data.get('test', [])):
                test_inputs.append(torch.tensor(test.get('input', []), dtype=torch.long))
                # Guard against a solutions list shorter than the test list;
                # missing outputs become empty tensors instead of IndexError.
                if task_solutions and i < len(task_solutions):
                    output_grid = task_solutions[i]
                else:
                    output_grid = []
                test_outputs.append(torch.tensor(output_grid, dtype=torch.long))
            task['test_inputs'] = test_inputs
            task['test_outputs'] = test_outputs

            self.tasks.append(task)

    def __len__(self):
        """Number of loaded tasks."""
        return len(self.tasks)

    def __getitem__(self, idx):
        """Return the task dict at position *idx*."""
        return self.tasks[idx]
def collate_fn(batch):
    """Unwrap a batch and return its first task dict unchanged.

    Task dicts hold variable-sized tensors that the default collation
    cannot stack, so with batch_size=1 the single task is passed through.
    """
    first, *_rest = batch
    return first
class ARCDataLoader:
    """Builds train/eval/test ARCDatasets and DataLoaders for the
    ARC Prize 2024 JSON files.

    Expects the Kaggle files under ``<input_path>/arc-prize-2024``.
    The test split ships without a solutions file, so its tasks carry
    empty ``test_outputs`` tensors.
    """

    def __init__(self, params):
        """params -- dict of settings:
        input_path -- directory containing arc-prize-2024/ (default './data')
        batch_size -- DataLoader batch size (default 1; collate_fn
                      returns only the first task of each batch)
        """
        # os.path.join instead of string '+' keeps the path portable.
        self.data_path = os.path.join(params.get('input_path', './data'), 'arc-prize-2024')
        self.batch_size = params.get('batch_size', 1)

        # File paths for each split.
        self.train_challenges_file = os.path.join(self.data_path, 'arc-agi_training_challenges.json')
        self.train_solutions_file = os.path.join(self.data_path, 'arc-agi_training_solutions.json')
        self.eval_challenges_file = os.path.join(self.data_path, 'arc-agi_evaluation_challenges.json')
        self.eval_solutions_file = os.path.join(self.data_path, 'arc-agi_evaluation_solutions.json')
        # No solutions file exists for the test set.
        self.test_challenges_file = os.path.join(self.data_path, 'arc-agi_test_challenges.json')

        self.train_dataset = ARCDataset(self.train_challenges_file, self.train_solutions_file)
        self.eval_dataset = ARCDataset(self.eval_challenges_file, self.eval_solutions_file)
        self.test_dataset = ARCDataset(self.test_challenges_file)

        self.train_loader = self._make_loader(self.train_dataset)
        self.eval_loader = self._make_loader(self.eval_dataset)
        self.test_loader = self._make_loader(self.test_dataset)

    def _make_loader(self, dataset):
        """One DataLoader configuration shared by all three splits."""
        return DataLoader(dataset=dataset, batch_size=self.batch_size,
                          shuffle=False, collate_fn=collate_fn)

    def _dataset_for(self, mode):
        """Return the dataset for *mode*; raise ValueError on unknown modes."""
        datasets = {'train': self.train_dataset,
                    'eval': self.eval_dataset,
                    'test': self.test_dataset}
        try:
            return datasets[mode]
        except KeyError:
            raise ValueError("Invalid mode. Choose from 'train', 'eval', or 'test'.") from None

    def get_train_loader(self):
        """DataLoader over the training split."""
        return self.train_loader

    def get_eval_loader(self):
        """DataLoader over the evaluation split."""
        return self.eval_loader

    def get_test_loader(self):
        """DataLoader over the test split (no solutions)."""
        return self.test_loader

    def get_data_loader(self, mode='train'):
        """Return the loader for *mode* ('train', 'eval', or 'test')."""
        loaders = {'train': self.get_train_loader,
                   'eval': self.get_eval_loader,
                   'test': self.get_test_loader}
        try:
            return loaders[mode]()
        except KeyError:
            raise ValueError("Invalid mode. Choose from 'train', 'eval', or 'test'.") from None

    def print_data_dirs(self):
        """Print every file found under the data directory."""
        for dirname, _, filenames in os.walk(self.data_path):
            for filename in filenames:
                print(os.path.join(dirname, filename))

    def get_task(self, task_index, task_id=None, mode='train'):
        """Fetch one task dict by *task_id* (preferred) or by *task_index*.

        Raises ValueError for an unknown mode, an unknown task_id, or when
        neither task_index nor task_id is provided.
        """
        dataset = self._dataset_for(mode)
        if task_id is not None:
            for task in dataset:
                if task['task_id'] == task_id:
                    return task
            raise ValueError(f"Task with task_id {task_id} not found in {mode} dataset.")
        if task_index is not None:
            return dataset[task_index]
        raise ValueError("Either task_index or task_id must be provided.")

    def print_task(self, task_num=0, mode='train'):
        """Pretty-print task *task_num* of the *mode* split: every demo and
        test grid with its tensor shape."""
        task = self._dataset_for(mode)[task_num]
        # Original always said "first task" regardless of task_num; report
        # the actual index instead.
        print(f"task {task_num} in {mode} set:")
        print(f"task_id: {task['task_id']}")
        print("\nDEMOS:\n")
        for idx, (demo_in, demo_out) in enumerate(zip(task['demo_inputs'], task['demo_outputs'])):
            print(f"demo_input {idx}: \n{demo_in}")
            print(f"demo_input shape {idx}:", demo_in.shape)
            print(f"demo_output {idx}: \n{demo_out}")
            print(f"demo_output shape {idx}:", demo_out.shape)
        print("\nTESTS:\n")
        for idx, (test_in, test_out) in enumerate(zip(task['test_inputs'], task['test_outputs'])):
            print(f"test_input {idx}: \n{test_in}")
            print(f"test_input shape {idx}:", test_in.shape)
            print(f"test_output {idx}: \n{test_out}")
            print(f"test_output shape {idx}:", test_out.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment