Skip to content

Instantly share code, notes, and snippets.

@devonartis
Created November 1, 2024 03:55
Show Gist options
  • Save devonartis/a308e0afb96b2daa4ca1a429bda1432f to your computer and use it in GitHub Desktop.
Save devonartis/a308e0afb96b2daa4ca1a429bda1432f to your computer and use it in GitHub Desktop.
Hate Content Detection Model Training Script. This script automates the process of training a hate content detection model using TensorFlow and KMeans clustering. It includes several key components. Device Setup: checks available GPUs for optimized processing. Data Processing: loads and preprocesses image data from specified training and test fol…
"""
Training script for hate content detection model
"""
import pandas as pd
import tensorflow as tf
from utils import ImageProcessor, SimpleMetalKMeans
def _flatten_and_scale(images):
    """Flatten each image into one feature row and divide the values by 255."""
    # One row per image; -1 lets the array infer the per-image feature count.
    # NOTE(review): dividing by 255 presumably assumes 8-bit pixel data -
    # confirm against ImageProcessor.load_images.
    return images.reshape(images.shape[0], -1) / 255.0


def train(
    train_folder='./Training_data',
    test_folder='./Test_data',
    output_path='predictions.csv',
    n_clusters=2,
    random_state=42,
    batch_size=1024,
):
    """Fit the clustering model on training images and predict on test images.

    The pipeline loads images from ``train_folder``, fits a
    ``SimpleMetalKMeans`` model, saves it via ``model.save()``, predicts on the
    images in ``test_folder`` and writes one row per test image to
    ``output_path``. Calling ``train()`` with no arguments uses the same
    folders, output file and hyperparameters as the original hard-coded
    script.

    Args:
        train_folder: Folder passed to ``ImageProcessor.load_images`` for the
            training set.
        test_folder: Folder passed to ``ImageProcessor.load_images`` for the
            test set.
        output_path: Destination of the predictions CSV (columns ``image_id``
            and ``prediction``).
        n_clusters: Forwarded to ``SimpleMetalKMeans``.
        random_state: Forwarded to ``SimpleMetalKMeans``.
        batch_size: Forwarded to ``SimpleMetalKMeans``.

    Raises:
        Exception: Any error raised by a pipeline step is printed and then
            re-raised unchanged.
    """
    print("Starting hate content detection model training...")

    # Print TensorFlow device info (only GPU devices are queried).
    physical_devices = tf.config.list_physical_devices('GPU')
    print("Available devices:", physical_devices)

    # Initialize processor
    processor = ImageProcessor()

    try:
        # Load and process training data. The training ids are not needed to
        # fit the model, so they are discarded.
        print("\nLoading training data...")
        train_images, _ = processor.load_images(train_folder)

        # Preprocess data
        print("\nPreprocessing training data...")
        train_data = _flatten_and_scale(train_images)
        print(f"Training data shape: {train_data.shape}")

        # Initialize and train model
        print("\nInitializing model...")
        model = SimpleMetalKMeans(
            n_clusters=n_clusters,
            random_state=random_state,
            batch_size=batch_size,
        )
        print("\nTraining model...")
        model.fit(train_data)

        # Save the trained model; the destination is chosen by
        # SimpleMetalKMeans.save itself.
        print("\nSaving trained model...")
        model.save()

        # Process test data with the same preprocessing as the training set.
        print("\nProcessing test data...")
        test_images, test_ids = processor.load_images(test_folder)
        test_data = _flatten_and_scale(test_images)

        # Make predictions
        print("\nMaking predictions...")
        predictions = model.predict(test_data)

        # Save results
        print("\nSaving results...")
        results_df = pd.DataFrame({
            'image_id': test_ids,
            'prediction': predictions,
        })
        results_df.to_csv(output_path, index=False)
        print(f"\nPredictions saved to '{output_path}'")
    except Exception as e:
        # Top-level boundary: report the failure, then let it propagate so the
        # caller (or the interpreter) still sees the original traceback.
        print(f"\nError during training: {e}")
        raise
if __name__ == "__main__":
    # Run the full training pipeline only when executed as a script,
    # not when this module is imported.
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment