Created
October 23, 2024 21:33
-
-
Save alexalbertt/0a85106e903a7febbc4491ebf71b9d6b to your computer and use it in GitHub Desktop.
Classifier script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
from sklearn.model_selection import train_test_split | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.metrics import classification_report, confusion_matrix | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# Load and preprocess the data | |
def load_and_preprocess_data(csv_path='~/Downloads/Air_Quality.csv'):
    """Load the air-quality CSV and derive the PM2.5 classification columns.

    Args:
        csv_path: Path or file-like object handed directly to
            ``pd.read_csv``. Defaults to the location this script
            previously hard-coded, so existing zero-argument calls behave
            as before.

    Returns:
        A copy of the rows whose ``Name`` equals
        ``'Fine particles (PM 2.5)'``, with two added columns:
        ``High_Pollution`` (1 where ``Data Value`` is strictly greater than
        the median ``Data Value`` of those rows, otherwise 0) and ``Year``
        (the year component of ``Start_Date``).
    """
    df = pd.read_csv(csv_path)

    # Keep only the PM 2.5 measurements; copy so the column assignments
    # below do not write into a slice of ``df``.
    pm25_data = df[df['Name'] == 'Fine particles (PM 2.5)'].copy()

    # Binary target: 1 if the reading is above the median, 0 otherwise.
    median_pm25 = pm25_data['Data Value'].median()
    pm25_data['High_Pollution'] = (
        pm25_data['Data Value'] > median_pm25
    ).astype(int)

    # Extract the year from Start_Date.
    pm25_data['Year'] = pd.to_datetime(pm25_data['Start_Date']).dt.year
    return pm25_data
def prepare_features(pm25_data):
    """Build the feature matrix and target, then split them for training.

    Args:
        pm25_data: DataFrame containing the ``'Geo Type Name'``, ``'Year'``
            and ``'High_Pollution'`` columns.

    Returns:
        The ``train_test_split`` result for the encoded features and the
        ``High_Pollution`` target, in the order
        ``X_train, X_test, y_train, y_test`` (20% held out,
        ``random_state=42``).
    """
    predictor_columns = ['Geo Type Name', 'Year']
    # One-hot encode the predictors, dropping the first level of each
    # encoded column.
    encoded = pd.get_dummies(pm25_data[predictor_columns], drop_first=True)
    labels = pm25_data['High_Pollution']
    return train_test_split(encoded, labels, test_size=0.2, random_state=42)
def train_model(X_train, X_test, y_train, y_test):
    """Standardize the features and fit a random forest classifier.

    Args:
        X_train: Training feature matrix; the scaler is fitted on this only.
        X_test: Test feature matrix; transformed with the fitted scaler.
        y_train: Training labels.
        y_test: Accepted but not used here.

    Returns:
        A ``(clf, X_test_scaled)`` tuple: the fitted
        ``RandomForestClassifier`` and the scaled test features.
    """
    # Fit the scaler on the training data, then apply it to both splits.
    standardizer = StandardScaler().fit(X_train)
    train_matrix = standardizer.transform(X_train)
    test_matrix = standardizer.transform(X_test)

    forest = RandomForestClassifier(random_state=42)
    forest.fit(train_matrix, y_train)
    return forest, test_matrix
def evaluate_model(clf, X_test_scaled, y_test, features):
    """Print a classification report and save two diagnostic plots.

    Args:
        clf: Fitted classifier exposing ``predict`` and
            ``feature_importances_``.
        X_test_scaled: Scaled test features to predict on.
        y_test: True labels for the test set.
        features: Object whose ``columns`` attribute names the features, in
            the same order as ``clf.feature_importances_``.

    Side effects:
        Prints the classification report and writes
        ``confusion_matrix.png`` and ``feature_importance.png``.
    """
    predictions = clf.predict(X_test_scaled)

    # Text report.
    print("\nClassification Report:")
    print(classification_report(y_test, predictions))

    # Confusion matrix heatmap.
    matrix = confusion_matrix(y_test, predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig('confusion_matrix.png')
    plt.close()

    # Feature importances, highest first; only the top ten are plotted.
    ranking = pd.DataFrame(
        {
            'feature': features.columns,
            'importance': clf.feature_importances_,
        }
    )
    ranking = ranking.sort_values('importance', ascending=False)
    top_ten = ranking.head(10)

    plt.figure(figsize=(12, 6))
    sns.barplot(data=top_ten, x='importance', y='feature')
    plt.title('Top 10 Most Important Features')
    plt.xlabel('Feature Importance')
    plt.tight_layout()
    plt.savefig('feature_importance.png')
    plt.close()
def plot_time_series(pm25_data):
    """Plot the mean ``Data Value`` per ``Year`` and save it to disk.

    Args:
        pm25_data: DataFrame with ``'Year'`` and ``'Data Value'`` columns.

    Side effects:
        Writes the line chart to ``pm25_trend.png``.
    """
    # Average reading for each year.
    mean_by_year = pm25_data.groupby('Year')['Data Value'].mean()
    years = mean_by_year.index
    levels = mean_by_year.values

    plt.figure(figsize=(12, 6))
    plt.plot(years, levels, marker='o')
    plt.title('Average PM2.5 Levels Over Time')
    plt.xlabel('Year')
    plt.ylabel('PM2.5 Level')
    plt.grid(True)
    plt.savefig('pm25_trend.png')
    plt.close()
def main():
    """Run the full pipeline: load, split, train, evaluate, plot the trend."""
    data = load_and_preprocess_data()

    splits = prepare_features(data)
    X_train, X_test, y_train, y_test = splits

    model, scaled_test = train_model(X_train, X_test, y_train, y_test)

    # X_train is passed so evaluate_model can read the feature column names.
    evaluate_model(model, scaled_test, y_test, X_train)
    plot_time_series(data)


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment