Created
November 1, 2024 03:55
-
-
Save devonartis/a308e0afb96b2daa4ca1a429bda1432f to your computer and use it in GitHub Desktop.
Hate Content Detection Model Training Script. This script automates the process of training a hate content detection model using TensorFlow and KMeans clustering. It includes several key components: (1) Device Setup: checks available GPUs for optimized processing. (2) Data Processing: loads and preprocesses image data from specified training and test fol…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Training script for hate content detection model
"""
import pandas as pd | |
import tensorflow as tf | |
from utils import ImageProcessor, SimpleMetalKMeans | |
def _flatten_and_scale(images):
    """Flatten each image into one row and divide the values by 255.

    Args:
        images: Object exposing ``shape`` and ``reshape`` (presumably a NumPy
            array whose first axis indexes images -- confirm against
            ``ImageProcessor.load_images``).

    Returns:
        ``images`` reshaped to ``(images.shape[0], -1)`` and divided by 255.0.
    """
    return images.reshape(images.shape[0], -1) / 255.0


def train(
    train_folder='./Training_data',
    test_folder='./Test_data',
    output_csv='predictions.csv',
    *,
    n_clusters=2,
    random_state=42,
    batch_size=1024,
):
    """Main training function.

    Fits a ``SimpleMetalKMeans`` model on the images in ``train_folder``,
    saves it via ``model.save()``, predicts on the images in ``test_folder``
    and writes the predictions to ``output_csv``. Called with no arguments it
    uses the same folders, file name and model settings the script always
    used.

    Args:
        train_folder: Folder handed to ``ImageProcessor.load_images`` for the
            training set.
        test_folder: Folder handed to ``ImageProcessor.load_images`` for the
            test set.
        output_csv: Path of the CSV file that receives the ``image_id`` and
            ``prediction`` columns.
        n_clusters: Forwarded to ``SimpleMetalKMeans``.
        random_state: Forwarded to ``SimpleMetalKMeans``.
        batch_size: Forwarded to ``SimpleMetalKMeans``.

    Raises:
        Exception: Any error raised while loading, fitting, predicting or
            saving is printed and then re-raised unchanged.
    """
    print("Starting hate content detection model training...")

    # Print TensorFlow device info
    physical_devices = tf.config.list_physical_devices('GPU')
    print("Available devices:", physical_devices)

    # Initialize processor
    processor = ImageProcessor()

    try:
        # Load and process training data
        print("\nLoading training data...")
        # NOTE(review): the ids returned for the training set were never used
        # by the original script, so they are discarded here.
        train_images, _ = processor.load_images(train_folder)

        # Preprocess data
        print("\nPreprocessing training data...")
        train_data = _flatten_and_scale(train_images)
        print(f"Training data shape: {train_data.shape}")

        # Initialize and train model
        print("\nInitializing model...")
        model = SimpleMetalKMeans(
            n_clusters=n_clusters,
            random_state=random_state,
            batch_size=batch_size,
        )

        print("\nTraining model...")
        model.fit(train_data)

        # Save the trained model before touching the test data, so a failure
        # during prediction does not lose the fitted model.
        print("\nSaving trained model...")
        model.save()

        # Process test data with the same preprocessing as the training data
        print("\nProcessing test data...")
        test_images, test_ids = processor.load_images(test_folder)
        test_data = _flatten_and_scale(test_images)

        # Make predictions
        print("\nMaking predictions...")
        predictions = model.predict(test_data)

        # Save results
        print("\nSaving results...")
        results_df = pd.DataFrame({
            'image_id': test_ids,
            'prediction': predictions,
        })
        results_df.to_csv(output_csv, index=False)
        # Report the path actually written rather than a hard-coded name.
        print(f"\nPredictions saved to '{output_csv}'")

    except Exception as e:
        # Top-level boundary: surface the error to the console, then let it
        # propagate so the process exits with a traceback.
        print(f"\nError during training: {e}")
        raise
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment