Created
March 26, 2025 17:41
-
-
Save fffej/654ddd6bec9556ddc636a9dc1eca3009 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Sequential | |
| from tensorflow.keras.layers import Dense, Dropout | |
| from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import confusion_matrix, classification_report | |
| import seaborn as sns | |
# Load the preprocessed data
print("Loading preprocessed data...")

# Selects which set of .npy files is read: the "-cnn" suffixed files when
# True, the unsuffixed files otherwise.
use_cnn = True
suffix = ""
if use_cnn:
    suffix = "-cnn"

# Files are read in the order listed; each is named "<array name><suffix>.npy"
# and is resolved relative to the current working directory.
X_train, X_test, X_val, y_train, y_test, y_val = (
    np.load(f"{array_name}{suffix}.npy")
    for array_name in (
        "X_train", "X_test", "X_val",
        "y_train", "y_test", "y_val",
    )
)

# Print shapes to confirm data loading
for split_label, features, targets in (
    ("Training", X_train, y_train),
    ("Test", X_test, y_test),
    ("Validation", X_val, y_val),
):
    print(f"{split_label} data shape: {features.shape}, {targets.shape}")
# Define the model
def create_model(input_shape, num_classes):
    """Build and compile a fully connected softmax classifier.

    The network is Dense(256) -> Dropout(0.3) -> Dense(128) -> Dropout(0.3)
    -> Dense(64) -> Dropout(0.2) -> Dense(num_classes, softmax), with ReLU
    activations on every layer except the last.

    Args:
        input_shape: Number of features per sample; it is wrapped into a
            one-element tuple for the first layer's ``input_shape``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The compiled ``Sequential`` model (Adam optimizer, categorical
        cross-entropy loss, accuracy metric).
    """
    # Input layer
    stack = [
        Dense(256, activation='relu', input_shape=(input_shape,)),
        Dropout(0.3),
    ]

    # Hidden layers: (units, dropout rate applied after the layer)
    for units, drop_rate in ((128, 0.3), (64, 0.2)):
        stack.append(Dense(units, activation='relu'))
        stack.append(Dropout(drop_rate))

    # Output layer
    stack.append(Dense(num_classes, activation='softmax'))

    model = Sequential(stack)

    # Compile the model
    compile_options = {
        'optimizer': 'adam',
        'loss': 'categorical_crossentropy',
        'metrics': ['accuracy'],
    }
    model.compile(**compile_options)
    return model
# Create and train the model
# The model's input width and output width are taken from the second
# dimension of the loaded arrays (the original author noted 4096 and 5).
input_shape = X_train.shape[1]
num_classes = y_train.shape[1]
model = create_model(input_shape, num_classes)
model.summary()

# Define callbacks
# Stop once val_loss has not improved for 10 epochs and roll the model back
# to the weights of its best epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)
# Write the model to disk whenever val_accuracy reaches a new best.
checkpoint = ModelCheckpoint(
    'chess_piece_model.h5',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)

# Train the model
history = model.fit(
    X_train, y_train,
    epochs=500,
    batch_size=60,
    validation_data=(X_val, y_val),
    # A callback only takes effect when it is handed to fit(); the checkpoint
    # must be listed here or 'chess_piece_model.h5' is never written.
    callbacks=[early_stopping, checkpoint],
    verbose=1,
)

# Evaluate the model
test_loss, test_accuracy = model.evaluate(X_test, y_test)
print(f"Test accuracy: {test_accuracy:.4f}")
print(f"Test loss: {test_loss:.4f}")
# Visualize training history
def plot_history(history):
    """Save side-by-side accuracy and loss curves to an image file.

    Args:
        history: Object whose ``history`` attribute maps the keys
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' to
            sequences of values (as returned by ``model.fit`` in this
            script).

    The figure is written to 'split_training_history.png' and then closed;
    nothing is returned.
    """
    # One entry per subplot:
    # (training key, validation key, training label, validation label,
    #  title, y-axis label)
    panels = (
        ('accuracy', 'val_accuracy',
         'Training Accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss',
         'Training Loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    )
    curves = history.history

    plt.figure(figsize=(12, 5))
    for slot, panel in enumerate(panels, start=1):
        train_key, val_key, train_label, val_label, title, y_label = panel
        plt.subplot(1, len(panels), slot)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.savefig('split_training_history.png')
    plt.close()
plot_history(history)

# Make predictions on the test set
y_pred = model.predict(X_test)
# Collapse per-class score rows to a single class index per sample.
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_test, axis=1)

class_names = ['Pawn', 'Knight', 'Rook', 'Queen', 'Bishop']
# Explicit label ids, one per entry of class_names. Without them scikit-learn
# infers the label set from the data: a class missing from the test split
# would shrink the confusion matrix (misaligning the tick labels) and make
# classification_report reject target_names for having the wrong length.
class_ids = list(range(len(class_names)))

# Generate confusion matrix
cm = confusion_matrix(y_true_classes, y_pred_classes, labels=class_ids)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.savefig('split_confusion_matrix.png')
plt.close()

# Print classification report
print("\nClassification Report:")
print(
    classification_report(
        y_true_classes,
        y_pred_classes,
        labels=class_ids,
        target_names=class_names,
    )
)
# Visualize some predictions
def visualize_predictions(X, y_true, y_pred, class_names, num_samples=5,
                          image_shape=(64, 64)):
    """Save a grid of randomly chosen samples titled with true/predicted class.

    The grid has ``len(class_names)`` rows and ``num_samples`` columns. Each
    chosen sample is reshaped to ``image_shape`` and drawn with the 'binary'
    colormap; its title is green when the predicted class name equals the
    true class name and red otherwise.

    Args:
        X: Indexable collection of samples; each ``X[idx]`` must be
            reshapeable to ``image_shape``.
        y_true: Per-sample score rows; the argmax of a row is the true class
            index into ``class_names``.
        y_pred: Per-sample score rows; the argmax of a row is the predicted
            class index into ``class_names``.
        class_names: Sequence of class names indexed by class id.
        num_samples: Number of grid columns (samples per row).
        image_shape: Shape each sample is reshaped to before display.

    The figure is written to 'split_prediction_samples.png' and then closed;
    nothing is returned.
    """
    n_rows, n_cols = len(class_names), num_samples

    # Sampling without replacement raises ValueError when more items are
    # requested than exist, so show at most len(X) samples; any remaining
    # grid cells are simply left empty.
    n_shown = min(n_rows * n_cols, len(X))

    plt.figure(figsize=(15, 10))

    # Randomly select samples
    indices = np.random.choice(len(X), n_shown, replace=False)

    for i, idx in enumerate(indices):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(X[idx].reshape(image_shape), cmap='binary')

        true_class = class_names[np.argmax(y_true[idx])]
        pred_class = class_names[np.argmax(y_pred[idx])]
        color = 'green' if true_class == pred_class else 'red'

        plt.title(f"True: {true_class}\nPred: {pred_class}", color=color)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('split_prediction_samples.png')
    plt.close()
# Visualize predictions
# NOTE(review): presumably skipped for the "-cnn" data because
# visualize_predictions reshapes every sample to a 64x64 image by default;
# confirm against the preprocessing script.
if not use_cnn:
    visualize_predictions(
        X=X_test,
        y_true=y_test,
        y_pred=y_pred,
        class_names=class_names,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment