Created
November 16, 2024 23:57
-
-
Save Sinjhin/4c6699cc529ed3a9286ea2c2e01946f4 to your computer and use it in GitHub Desktop.
ARC-AGI Loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import torch | |
import numpy as np | |
from torch.utils.data import Dataset, DataLoader | |
class ARCDataset(Dataset):
    """PyTorch Dataset over ARC-AGI tasks loaded eagerly from JSON files.

    Each item is a dict with keys:
        task_id      -- the task's string identifier
        demo_inputs  -- list of LongTensors, one per training-example input grid
        demo_outputs -- list of LongTensors, one per training-example output grid
        test_inputs  -- list of LongTensors, one per test input grid
        test_outputs -- list of LongTensors (empty tensors when no solution known)
    """

    def __init__(self, challenges_file, solutions_file=None):
        """Load every task into memory.

        challenges_file -- path to JSON mapping task_id -> {'train': [...], 'test': [...]}
        solutions_file  -- optional path to JSON mapping task_id -> list of output
                           grids (one per test input); omit for the hidden test set.
        """
        self.tasks = []

        with open(challenges_file, 'r') as f:
            challenges = json.load(f)

        # Solutions (when provided) map task_id -> list of solution grids.
        if solutions_file:
            with open(solutions_file, 'r') as f:
                solutions_dict = json.load(f)
        else:
            solutions_dict = {}

        for task_id, task_data in challenges.items():
            task = {'task_id': task_id}

            # Training ("demo") pairs: convert each grid to a LongTensor.
            demo_inputs = []
            demo_outputs = []
            for example in task_data.get('train', []):
                demo_inputs.append(torch.tensor(example['input'], dtype=torch.long))
                demo_outputs.append(torch.tensor(example['output'], dtype=torch.long))
            task['demo_inputs'] = demo_inputs
            task['demo_outputs'] = demo_outputs

            # Test pairs: outputs come from the solutions file when available.
            # BUG FIX: the original rebound the name `solutions` inside this loop
            # (shadowing the loaded solutions JSON) and indexed it without a
            # bounds check, raising IndexError when a task has fewer solution
            # grids than test inputs. Use a distinct name and guard the index.
            task_solutions = solutions_dict.get(task_id) or []
            test_inputs = []
            test_outputs = []
            for i, test in enumerate(task_data.get('test', [])):
                test_inputs.append(torch.tensor(test.get('input', []), dtype=torch.long))
                output_grid = task_solutions[i] if i < len(task_solutions) else []
                test_outputs.append(torch.tensor(output_grid, dtype=torch.long))
            task['test_inputs'] = test_inputs
            task['test_outputs'] = test_outputs

            self.tasks.append(task)

    def __len__(self):
        """Number of tasks loaded."""
        return len(self.tasks)

    def __getitem__(self, idx):
        """Return the task dict at position idx (no transformation applied)."""
        return self.tasks[idx]
def collate_fn(batch):
    """Pass-through collate: with batch_size=1, yield the task dict itself."""
    first_task = batch[0]
    return first_task
class ARCDataLoader:
    """Builds train/eval/test ARCDatasets and DataLoaders for the ARC Prize 2024
    JSON files expected under ``<input_path>/arc-prize-2024``.
    """

    def __init__(self, params):
        """params -- dict with optional keys:
            input_path -- root data directory (default './data')
            batch_size -- DataLoader batch size (default 1)
        """
        self.data_path = params.get('input_path', './data') + '/arc-prize-2024'
        self.batch_size = params.get('batch_size', 1)

        self.train_challenges_file = os.path.join(self.data_path, 'arc-agi_training_challenges.json')
        self.train_solutions_file = os.path.join(self.data_path, 'arc-agi_training_solutions.json')
        self.eval_challenges_file = os.path.join(self.data_path, 'arc-agi_evaluation_challenges.json')
        self.eval_solutions_file = os.path.join(self.data_path, 'arc-agi_evaluation_solutions.json')
        # The held-out test set ships with no solutions file.
        self.test_challenges_file = os.path.join(self.data_path, 'arc-agi_test_challenges.json')

        self.train_dataset = ARCDataset(self.train_challenges_file, self.train_solutions_file)
        self.train_loader = DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)
        self.eval_dataset = ARCDataset(self.eval_challenges_file, self.eval_solutions_file)
        self.eval_loader = DataLoader(dataset=self.eval_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)
        self.test_dataset = ARCDataset(self.test_challenges_file)
        self.test_loader = DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)

    def get_train_loader(self):
        """Return the training DataLoader."""
        return self.train_loader

    def get_eval_loader(self):
        """Return the evaluation DataLoader."""
        return self.eval_loader

    def get_test_loader(self):
        """Return the test DataLoader."""
        return self.test_loader

    def get_data_loader(self, mode='train'):
        """Return the DataLoader for mode ('train', 'eval', or 'test')."""
        if mode == 'train':
            return self.get_train_loader()
        elif mode == 'eval':
            return self.get_eval_loader()
        elif mode == 'test':
            return self.get_test_loader()
        else:
            raise ValueError("Invalid mode. Choose from 'train', 'eval', or 'test'.")

    def _dataset_for_mode(self, mode):
        """Map a mode string to its dataset; raise ValueError on an unknown mode.

        Consolidates the mode dispatch that was duplicated verbatim in
        get_task and print_task.
        """
        if mode == 'train':
            return self.train_dataset
        if mode == 'eval':
            return self.eval_dataset
        if mode == 'test':
            return self.test_dataset
        raise ValueError("Invalid mode. Choose from 'train', 'eval', or 'test'.")

    def print_data_dirs(self):
        """Print every file found (recursively) under the data directory."""
        for dirname, _, filenames in os.walk(self.data_path):
            for filename in filenames:
                print(os.path.join(dirname, filename))

    def get_task(self, task_index=None, task_id=None, mode='train'):
        """Fetch one task dict from the given split.

        Lookup by task_id (linear search) takes precedence over task_index.
        FIX: task_index now defaults to None so callers may look up purely by
        task_id without passing a positional placeholder (backward compatible).
        Raises ValueError if the id is not found, the mode is invalid, or
        neither selector is provided.
        """
        dataset = self._dataset_for_mode(mode)
        if task_id is not None:
            for task in dataset:
                if task['task_id'] == task_id:
                    return task
            raise ValueError(f"Task with task_id {task_id} not found in {mode} dataset.")
        if task_index is not None:
            return dataset[task_index]
        raise ValueError("Either task_index or task_id must be provided.")

    def print_task(self, task_num=0, mode='train'):
        """Pretty-print the demo and test grids of one task (by index)."""
        task = self._dataset_for_mode(mode)[task_num]
        # FIX: the original always printed "first task" regardless of task_num.
        print(f"task {task_num} in {mode} set:")
        print(f"task_id: {task['task_id']}")
        print("\nDEMOS:\n")
        for idx, (demo_in, demo_out) in enumerate(zip(task['demo_inputs'], task['demo_outputs'])):
            print(f"demo_input {idx}: \n{demo_in}")
            print(f"demo_input shape {idx}:", demo_in.shape)
            print(f"demo_output {idx}: \n{demo_out}")
            print(f"demo_output shape {idx}:", demo_out.shape)
        print("\nTESTS:\n")
        for idx, (test_in, test_out) in enumerate(zip(task['test_inputs'], task['test_outputs'])):
            print(f"test_input {idx}: \n{test_in}")
            print(f"test_input shape {idx}:", test_in.shape)
            print(f"test_output {idx}: \n{test_out}")
            print(f"test_output shape {idx}:", test_out.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment