Created
March 26, 2025 17:39
-
-
Save fffej/79d21b4fdb541d6e2deb534bafbeaefa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import numpy as np | |
| from PIL import Image | |
| from sklearn.model_selection import train_test_split | |
| import matplotlib.pyplot as plt | |
# Map every image directory to the one-hot label of the piece it holds.
# Each pair below is (directory, position of the 1 in the 5-element label);
# the positions follow the order pawn, knight, rook, queen, bishop.
directories = {
    path: [1 if position == hot_index else 0 for position in range(5)]
    for path, hot_index in (
        ('data/pawn_resized-BW', 0),
        ('data/bishop_resized-BW', 4),
        ('data/knight-resize-BW', 1),
        ('data/Rook-resize-BW', 2),
        ('data/Queen-Resized-BW', 3),
    )
}
def _plot_class_distribution(class_counts):
    """Save a bar chart of image counts per class to 'class_distribution.png'."""
    plt.figure(figsize=(10, 5))
    # Materialise the dict views so matplotlib receives plain sequences.
    plt.bar(list(class_counts.keys()), list(class_counts.values()))
    plt.title('Number of Images per Chess Piece Class')
    plt.xlabel('Chess Piece')
    plt.ylabel('Number of Images')
    plt.savefig('class_distribution.png')
    plt.close()


def load_and_preprocess_images(directory_dict):
    """Load PNG images from directories and turn them into binary vectors.

    Every ``.png`` file (case-insensitive extension) in each directory is
    converted to grayscale, thresholded (pixel > 128 -> 1, otherwise 0) and
    flattened into a 1-D vector.

    Args:
        directory_dict: Mapping of directory path -> label. The label is
            attached unchanged to every image found in that directory.

    Returns:
        A tuple ``(X, y)`` of NumPy arrays: ``X`` holds one flattened binary
        image per row and ``y`` the matching labels in the same order. All
        images must share the same dimensions so the rows can be stacked
        into a single 2-D array.

    Side effects:
        Prints a progress line per directory and writes a bar chart of the
        per-class image counts to ``class_distribution.png``.
    """
    X = []  # Flattened binary images
    y = []  # Labels, parallel to X
    class_counts = {}  # Class name -> number of images, for the bar chart

    for directory, label in directory_dict.items():
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order (and any seeded split made from it) reproducible.
        png_files = sorted(
            f for f in os.listdir(directory) if f.lower().endswith('.png')
        )

        # Class name is the directory basename up to the first '-' or '_',
        # e.g. 'data/pawn_resized-BW' -> 'pawn'.
        class_name = os.path.basename(directory).split('-')[0].split('_')[0]
        class_counts[class_name] = len(png_files)
        print(f"Loading {len(png_files)} images from {directory}")

        for img_file in png_files:
            img_path = os.path.join(directory, img_file)
            # The context manager closes the underlying file promptly instead
            # of leaving it to garbage collection.
            with Image.open(img_path) as img:
                img_array = np.array(img.convert('L'))  # Grayscale pixels
            # Threshold to a 0/1 image, then flatten to a single vector.
            binary_array = (img_array > 128).astype(int)
            X.append(binary_array.flatten())
            y.append(label)

    X = np.array(X)
    y = np.array(y)

    _plot_class_distribution(class_counts)
    return X, y
def split_data(X, y, train_size=0.6, test_size=0.2, val_size=0.2):
    """Split data into stratified training, test, and validation sets.

    The split is done in two steps with a fixed random seed (42): first the
    training set is carved out, then the remainder is divided between test
    and validation in the ratio ``test_size : val_size``.

    Args:
        X: Samples.
        y: Labels, parallel to ``X``; also used for stratification.
        train_size: Size of the training set, passed to ``train_test_split``.
            When given as a float it is a fraction of the whole dataset and
            the three sizes must sum to 1.
        test_size: Relative size of the test set; must be positive.
        val_size: Relative size of the validation set; must be positive.

    Returns:
        ``(X_train, X_test, X_val, y_train, y_test, y_val)``.

    Raises:
        ValueError: If ``test_size`` or ``val_size`` is not positive, or if
            ``train_size`` is a float and the three sizes do not sum to 1.
    """
    if test_size <= 0 or val_size <= 0:
        raise ValueError(
            f"test_size and val_size must be positive, "
            f"got test_size={test_size}, val_size={val_size}"
        )

    # Only the ratio test_size : val_size is used below, so fractions that do
    # not add up to 1 would silently produce different set sizes than asked.
    if isinstance(train_size, float):
        total = train_size + test_size + val_size
        if abs(total - 1.0) > 1e-9:
            raise ValueError(
                f"train_size + test_size + val_size must equal 1, got {total}"
            )

    # First split: training set versus everything else.
    X_train, X_rest, y_train, y_rest = train_test_split(
        X, y, train_size=train_size, stratify=y, random_state=42
    )

    # Second split: share of the remainder that goes to the test set.
    test_proportion = test_size / (test_size + val_size)
    X_test, X_val, y_test, y_val = train_test_split(
        X_rest, y_rest, train_size=test_proportion, stratify=y_rest,
        random_state=42
    )

    return X_train, X_test, X_val, y_train, y_test, y_val
def main():
    """Run the pipeline: load images, split, report, save, and visualise."""
    print("Loading and preprocessing images...")
    X, y = load_and_preprocess_images(directories)

    print("Splitting data into training, test, and validation sets...")
    X_train, X_test, X_val, y_train, y_test, y_val = split_data(X, y)

    # Report how many samples ended up in each set.
    print(f"Total number of samples: {len(X)}")
    print(f"Training set: {len(X_train)} samples")
    print(f"Test set: {len(X_test)} samples")
    print(f"Validation set: {len(X_val)} samples")

    # Column-wise sums of the label arrays give the per-class sample counts.
    train_counts, test_counts, val_counts = (
        np.sum(labels, axis=0) for labels in (y_train, y_test, y_val)
    )
    print("\nClass distribution:")
    print(f"Training set: {train_counts}")
    print(f"Test set: {test_counts}")
    print(f"Validation set: {val_counts}")

    # Persist every split as its own .npy file in the working directory.
    print("\nSaving preprocessed data...")
    outputs = (
        ('X_train.npy', X_train),
        ('X_test.npy', X_test),
        ('X_val.npy', X_val),
        ('y_train.npy', y_train),
        ('y_test.npy', y_test),
        ('y_val.npy', y_val),
    )
    for filename, array in outputs:
        np.save(filename, array)

    # Finally, render a few example images per class.
    visualize_samples(X, y)
def visualize_samples(X, y):
    """Save a grid with up to five example images per class.

    Args:
        X: Array of flattened images; each row is reshaped to 64x64 for
            display.
        y: Array of per-sample label vectors; the position of the largest
            entry in each row is taken as the class index.

    Side effects:
        Writes the figure to ``sample_images.png``. One grid row is drawn per
        class; classes with fewer than five samples leave cells empty.
    """
    class_names = ['Pawn', 'Knight', 'Rook', 'Queen', 'Bishop']
    samples_per_class = 5
    image_side = 64

    plt.figure(figsize=(15, 10))

    # The class index of every sample is loop-invariant, so compute it once.
    sample_classes = np.argmax(y, axis=1)

    for row, class_name in enumerate(class_names):
        indices = np.where(sample_classes == row)[0]
        # Slicing caps the count at samples_per_class without extra guards.
        for col, sample_idx in enumerate(indices[:samples_per_class]):
            plt.subplot(
                len(class_names), samples_per_class,
                row * samples_per_class + col + 1
            )
            plt.imshow(
                X[sample_idx].reshape(image_side, image_side), cmap='binary'
            )
            plt.title(f"{class_name}")
            plt.axis('off')

    plt.tight_layout()
    plt.savefig('sample_images.png')
    plt.close()
# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger loading, saving, or plotting.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment