Skip to content

Instantly share code, notes, and snippets.

@fffej
Created March 26, 2025 17:41
Show Gist options
  • Save fffej/654ddd6bec9556ddc636a9dc1eca3009 to your computer and use it in GitHub Desktop.
Save fffej/654ddd6bec9556ddc636a9dc1eca3009 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
# Load the preprocessed data
print("Loading preprocessed data...")
use_cnn = True
# use_cnn selects between the plain .npy files and the "-cnn" variants.
suffix = "-cnn" if use_cnn else ""
# Loaded in the same order as the names on the left: features, then labels,
# each as train / test / validation.
X_train, X_test, X_val, y_train, y_test, y_val = [
    np.load(f'{split_name}{suffix}.npy')
    for split_name in ('X_train', 'X_test', 'X_val', 'y_train', 'y_test', 'y_val')
]
# Print shapes to confirm data loading
print(f"Training data shape: {X_train.shape}, {y_train.shape}")
print(f"Test data shape: {X_test.shape}, {y_test.shape}")
print(f"Validation data shape: {X_val.shape}, {y_val.shape}")
# Define the model
def create_model(input_shape, num_classes, *, hidden_units=(256, 128, 64),
                 dropout_rates=(0.3, 0.3, 0.2)):
    """Build and compile a fully connected softmax classifier.

    Args:
        input_shape: Number of input features per sample (an int); the
            network's input layer has shape ``(input_shape,)``.
        num_classes: Number of units in the softmax output layer.
        hidden_units: Width of each ReLU hidden layer, in order.
        dropout_rates: Dropout rate applied after each hidden layer; must have
            the same length as ``hidden_units``.

    Returns:
        A ``Sequential`` model compiled with the Adam optimizer, categorical
        cross-entropy loss (targets are expected one-hot encoded) and an
        accuracy metric.

    Raises:
        ValueError: If ``hidden_units`` and ``dropout_rates`` differ in length.
    """
    if len(hidden_units) != len(dropout_rates):
        raise ValueError(
            f"hidden_units ({len(hidden_units)}) and dropout_rates "
            f"({len(dropout_rates)}) must have the same length"
        )
    # An explicit Input layer is the form Keras recommends for Sequential
    # models; passing input_shape= to the first Dense layer is discouraged in
    # Keras 3, which emits a warning for it.
    stack = [tf.keras.Input(shape=(input_shape,))]
    # Hidden layers: each Dense(ReLU) is followed by its own Dropout.
    for units, rate in zip(hidden_units, dropout_rates):
        stack.append(Dense(units, activation='relu'))
        stack.append(Dropout(rate))
    # Output layer
    stack.append(Dense(num_classes, activation='softmax'))
    model = Sequential(stack)
    # Compile the model
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model
# Create and train the model
# Size the network from the training arrays: the label width gives the class
# count (the original notes 5) and the feature width gives the input size
# (the original notes 4096).
num_classes = y_train.shape[1]
input_shape = X_train.shape[1]
model = create_model(input_shape=input_shape, num_classes=num_classes)
model.summary()
# Define callbacks
# Stop after 10 epochs without a val_loss improvement and roll the model back
# to the weights from its best val_loss epoch.
early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                               restore_best_weights=True)
# Write 'chess_piece_model.h5' each time val_accuracy reaches a new best.
# NOTE(review): this only has an effect if checkpoint is included in the
# callbacks list passed to model.fit.
checkpoint = ModelCheckpoint('chess_piece_model.h5', monitor='val_accuracy',
                             save_best_only=True, verbose=1)
# Train the model
# Both callbacks are attached so the checkpoint actually runs and writes
# 'chess_piece_model.h5' on each new best val_accuracy. The two monitor
# different quantities (val_loss vs val_accuracy), so the saved file and the
# weights restored in memory at the end of training can come from different
# epochs.
history = model.fit(
    X_train, y_train,
    epochs=500,  # upper bound; early_stopping may end training sooner
    batch_size=60,
    validation_data=(X_val, y_val),
    callbacks=[early_stopping, checkpoint],
    verbose=1,
)
# Evaluate the model on the held-out test split. evaluate() returns the loss
# followed by the compiled metrics, which here is just accuracy.
(test_loss, test_accuracy) = model.evaluate(x=X_test, y=y_test)
print(f"Test accuracy: {test_accuracy:.4f}")
print(f"Test loss: {test_loss:.4f}")
# Visualize training history
def plot_history(history):
    """Save side-by-side accuracy and loss curves for a training run.

    Args:
        history: The object returned by ``model.fit``; its ``history`` dict
            must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.

    The figure is written to ``split_training_history.png`` and then closed.
    """
    # One row per subplot, left to right:
    # (train key, val key, train label, val label, title, y-axis label)
    panels = [
        ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Training Loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    ]
    curves = history.history
    plt.figure(figsize=(12, 5))
    for slot, panel in enumerate(panels, start=1):
        train_key, val_key, train_label, val_label, title, y_label = panel
        plt.subplot(1, 2, slot)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.savefig('split_training_history.png')
    plt.close()
plot_history(history)
# Make predictions on the test set
y_pred = model.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_test, axis=1)
# Display names indexed by class id. NOTE(review): this order must match the
# column order of the one-hot labels produced during preprocessing — confirm.
class_names = ['Pawn', 'Knight', 'Rook', 'Queen', 'Bishop']
if y_test.shape[1] != len(class_names):
    raise ValueError(
        f"Label width {y_test.shape[1]} does not match the "
        f"{len(class_names)} entries in class_names"
    )
# Pin the label set explicitly. Without it sklearn infers labels from the
# values present, so a class missing from both y_true and y_pred would shrink
# the matrix (misaligning the heatmap tick labels) and make
# classification_report reject target_names.
class_ids = np.arange(len(class_names))
# Generate confusion matrix
cm = confusion_matrix(y_true_classes, y_pred_classes, labels=class_ids)
# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.savefig('split_confusion_matrix.png')
plt.close()
# Print classification report
print("\nClassification Report:")
print(classification_report(y_true_classes, y_pred_classes,
                            labels=class_ids, target_names=class_names))
# Visualize some predictions
def visualize_predictions(X, y_true, y_pred, class_names, num_samples=5, *,
                          image_shape=(64, 64), rng=None):
    """Save a grid of randomly chosen samples titled with true/predicted class.

    The grid has ``len(class_names)`` rows and ``num_samples`` columns. Samples
    are drawn at random from ``X`` (not stratified by class), so a row does not
    necessarily correspond to a single class.

    Args:
        X: Samples, one per leading index; each must be reshapeable to
            ``image_shape``.
        y_true: Label vectors aligned with ``X``; the class is their argmax.
        y_pred: Per-class scores aligned with ``X``; the class is their argmax.
        class_names: Display names indexed by class id.
        num_samples: Number of grid columns.
        image_shape: Shape each sample is reshaped to for display.
        rng: Optional ``numpy.random.Generator`` or ``RandomState`` for
            reproducible sampling; defaults to NumPy's global random state.

    The figure is written to ``split_prediction_samples.png`` and then closed.
    """
    plt.figure(figsize=(15, 10))
    n_rows, n_cols = len(class_names), num_samples
    # Sampling without replacement cannot draw more items than exist, so cap
    # the request at len(X) instead of letting choice() raise ValueError on a
    # small test set.
    n_shown = min(len(X), n_rows * n_cols)
    sampler = np.random if rng is None else rng
    indices = sampler.choice(len(X), n_shown, replace=False)
    for position, idx in enumerate(indices, start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(X[idx].reshape(image_shape), cmap='binary')
        true_id = np.argmax(y_true[idx])
        pred_id = np.argmax(y_pred[idx])
        true_class = class_names[true_id]
        pred_class = class_names[pred_id]
        # Compare ids rather than names so duplicate display names can't
        # mark a wrong prediction as correct. Green = hit, red = miss.
        color = 'green' if true_id == pred_id else 'red'
        plt.title(f"True: {true_class}\nPred: {pred_class}", color=color)
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('split_prediction_samples.png')
    plt.close()
# Visualize predictions
# Skipped for the "-cnn" data. NOTE(review): presumably because those arrays
# cannot be reshaped to 64x64 images for display — confirm.
if not use_cnn:
    visualize_predictions(
        X=X_test,
        y_true=y_test,
        y_pred=y_pred,
        class_names=class_names,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment