Skip to content

Instantly share code, notes, and snippets.

@fffej
Created March 26, 2025 17:39
Show Gist options
  • Save fffej/79d21b4fdb541d6e2deb534bafbeaefa to your computer and use it in GitHub Desktop.
Save fffej/79d21b4fdb541d6e2deb534bafbeaefa to your computer and use it in GitHub Desktop.
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# Map each source image directory to its one-hot class label.
# Label positions: 0 Pawn, 1 Knight, 2 Rook, 3 Queen, 4 Bishop — the same order
# as the class_names list used when plotting sample images.
# Insertion order is kept deliberately: it is the order images are loaded in.
directories = {
    path: [1 if position == hot_index else 0 for position in range(5)]
    for path, hot_index in (
        ('data/pawn_resized-BW', 0),
        ('data/bishop_resized-BW', 4),
        ('data/knight-resize-BW', 1),
        ('data/Rook-resize-BW', 2),
        ('data/Queen-Resized-BW', 3),
    )
}
def _plot_class_distribution(class_counts):
    """Save a bar chart of image counts per class to class_distribution.png."""
    plt.figure(figsize=(10, 5))
    # Pass plain lists rather than dict views: np.asarray(d.values()) yields a
    # 0-d object array, not a 1-D sequence, so views are fragile array-likes.
    plt.bar(list(class_counts.keys()), list(class_counts.values()))
    plt.title('Number of Images per Chess Piece Class')
    plt.xlabel('Chess Piece')
    plt.ylabel('Number of Images')
    plt.savefig('class_distribution.png')
    plt.close()


def load_and_preprocess_images(directory_dict):
    """Load PNG images from labelled directories as flat binary vectors.

    Every ``*.png`` file (case-insensitive extension) in each directory is
    converted to grayscale ('L' mode) and thresholded: pixels > 128 become 1,
    all others 0. The result is flattened into a 1-D vector. Files are
    processed in sorted filename order, so the sample order is reproducible.

    As a side effect, a bar chart of per-class image counts is written to
    ``class_distribution.png`` in the current working directory.

    Args:
        directory_dict: Mapping of directory path -> label. Directories are
            visited in the mapping's iteration order. The label object is
            appended once for every image found in that directory.

    Returns:
        Tuple ``(X, y)`` of numpy arrays. Row ``i`` of ``X`` is one flattened
        binary image and row ``i`` of ``y`` is its label.
    """
    X = []  # Image data: one flattened binary vector per image
    y = []  # Labels, parallel to X
    class_counts = {}  # Short class name -> number of images (for the plot)

    for directory, label in directory_dict.items():
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting makes the sample order reproducible across machines, and so
        # also the seeded train/test split that follows.
        png_files = sorted(
            f for f in os.listdir(directory) if f.lower().endswith('.png')
        )

        # e.g. 'pawn_resized-BW' -> 'pawn', 'Rook-resize-BW' -> 'Rook'
        class_name = os.path.basename(directory).split('-')[0].split('_')[0]
        # Accumulate so two directories mapping to one name are not lost.
        class_counts[class_name] = class_counts.get(class_name, 0) + len(png_files)
        print(f"Loading {len(png_files)} images from {directory}")

        for img_file in png_files:
            img_path = os.path.join(directory, img_file)
            # The context manager closes the underlying file once the pixel
            # data has been copied out, instead of leaking one handle per image.
            with Image.open(img_path) as img:
                img_array = np.array(img.convert('L'))  # grayscale
            # Binarize, then flatten row-major.
            # NOTE(review): downstream code reshapes rows to 64x64 (4096
            # values); nothing here enforces that image size — confirm inputs.
            binary_array = (img_array > 128).astype(int)
            X.append(binary_array.flatten())
            y.append(label)

    X = np.array(X)
    y = np.array(y)

    _plot_class_distribution(class_counts)
    return X, y
def split_data(X, y, train_size=0.6, test_size=0.2, val_size=0.2,
               random_state=42):
    """Split data into stratified training, test, and validation sets.

    The split happens in two stages:

    1. ``train_size`` of all samples go to the training set.
    2. The remainder is divided between test and validation in the ratio
       ``test_size : val_size``.

    ``test_size`` and ``val_size`` therefore only define that ratio. They are
    not checked against ``train_size``; for example, ``train_size=0.8`` with
    the defaults gives an 80/10/10 split.

    Both stages pass the labels as ``stratify``, so class proportions are
    kept (approximately) in every subset. NOTE(review): with one-hot ``y``,
    scikit-learn treats each distinct label row as a class — confirm this
    holds for your sklearn version.

    Args:
        X: Feature array, one sample per row.
        y: Labels aligned with ``X``; also used for stratification.
        train_size: Fraction of all samples assigned to the training set.
        test_size: Relative weight of the test set within the remainder.
        val_size: Relative weight of the validation set within the remainder.
        random_state: Seed passed to both ``train_test_split`` calls. It
            defaults to 42, the value previously hard-coded, so existing
            results are unchanged.

    Returns:
        ``(X_train, X_test, X_val, y_train, y_test, y_val)``.

    Raises:
        ValueError: If ``test_size + val_size`` is not positive. Previously
            this surfaced as a ZeroDivisionError or a meaningless proportion.
    """
    remainder_weight = test_size + val_size
    if remainder_weight <= 0:
        raise ValueError(
            "test_size + val_size must be positive, got "
            f"test_size={test_size}, val_size={val_size}"
        )

    # First split: training set vs. everything else.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, train_size=train_size, stratify=y, random_state=random_state
    )

    # Second split: test vs. validation from the remainder. The "train" part
    # returned by train_test_split is used as the test set here.
    test_proportion = test_size / remainder_weight
    X_test, X_val, y_test, y_val = train_test_split(
        X_temp, y_temp, train_size=test_proportion, stratify=y_temp,
        random_state=random_state,
    )
    return X_train, X_test, X_val, y_train, y_test, y_val
def main():
    """Run the preprocessing pipeline.

    Steps: load the images, split them, print set sizes and class balance,
    save every split as a ``.npy`` file, and plot example images.
    """
    print("Loading and preprocessing images...")
    features, labels = load_and_preprocess_images(directories)

    print("Splitting data into training, test, and validation sets...")
    (X_train, X_test, X_val,
     y_train, y_test, y_val) = split_data(features, labels)

    # (display title, features, labels) for each subset, in report order.
    subsets = [
        ("Training set", X_train, y_train),
        ("Test set", X_test, y_test),
        ("Validation set", X_val, y_val),
    ]

    print(f"Total number of samples: {len(features)}")
    for title, subset_features, _ in subsets:
        print(f"{title}: {len(subset_features)} samples")

    # Per-class sample counts: summing one-hot label rows column-wise.
    class_totals = [
        (title, np.sum(subset_labels, axis=0))
        for title, _, subset_labels in subsets
    ]
    print("\nClass distribution:")
    for title, totals in class_totals:
        print(f"{title}: {totals}")

    print("\nSaving preprocessed data...")
    outputs = {
        'X_train.npy': X_train,
        'X_test.npy': X_test,
        'X_val.npy': X_val,
        'y_train.npy': y_train,
        'y_test.npy': y_test,
        'y_val.npy': y_val,
    }
    for filename, array in outputs.items():
        np.save(filename, array)

    # Plot a few examples of each class from the full (unsplit) dataset.
    visualize_samples(features, labels)
def visualize_samples(X, y, samples_per_class=5, image_shape=(64, 64)):
    """Save a grid of example images to ``sample_images.png``.

    The grid has one row per class and up to ``samples_per_class`` columns.
    For each class, the first matching samples in ``X`` are shown.

    Args:
        X: 2-D array of flattened images, one per row. Each row must have
            ``image_shape[0] * image_shape[1]`` elements.
        y: 2-D label array aligned with ``X``. A row's class is its
            ``argmax`` along axis 1, indexing into the class names below.
        samples_per_class: Maximum images shown per class (grid columns).
            Defaults to 5, the previously hard-coded value.
        image_shape: ``(height, width)`` each row is reshaped to for display.
            Defaults to ``(64, 64)``, the previously hard-coded value.
    """
    class_names = ['Pawn', 'Knight', 'Rook', 'Queen', 'Bishop']
    n_rows = len(class_names)
    # Decode labels once, not once per class inside the loop.
    sample_classes = np.argmax(y, axis=1)

    plt.figure(figsize=(15, 10))
    for class_idx, class_name in enumerate(class_names):
        # Indices of samples of this class, capped at samples_per_class.
        indices = np.where(sample_classes == class_idx)[0][:samples_per_class]
        for j, sample_idx in enumerate(indices):
            plt.subplot(n_rows, samples_per_class,
                        class_idx * samples_per_class + j + 1)
            plt.imshow(X[sample_idx].reshape(image_shape), cmap='binary')
            plt.title(class_name)
            plt.axis('off')
    plt.tight_layout()
    plt.savefig('sample_images.png')
    plt.close()
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment