Created
October 31, 2024 23:13
-
-
Save metehus/3eef42bf8312a725f1439dcc44ec93da to your computer and use it in GitHub Desktop.
Base Covid.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/metehus/3eef42bf8312a725f1439dcc44ec93da/base-covid.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Base professor abaixo" | |
], | |
"metadata": { | |
"id": "qVtaBoG2l5jJ" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import pandas as pd\n", | |
"from sklearn.model_selection import StratifiedKFold, GridSearchCV\n", | |
"from sklearn.ensemble import RandomForestClassifier\n", | |
"from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score\n", | |
"from sklearn.preprocessing import StandardScaler\n", | |
"from imblearn.pipeline import Pipeline\n", | |
"from imblearn.over_sampling import SMOTE\n", | |
"\n", | |
"# Carregar dados\n", | |
"lbp_test_url = \"https://gist.github.com/metehus/315e5fd46adbc2dc3433d2e5b5728409/raw/28204aa227b80d03a443988652feb06ba12f0729/lbp-test.csv\"\n", | |
"lbp_train_fold0_url = \"https://gist.github.com/metehus/315e5fd46adbc2dc3433d2e5b5728409/raw/28204aa227b80d03a443988652feb06ba12f0729/lbp-train-fold_0.csv\"\n", | |
"lbp_train_fold1_url = \"https://gist.github.com/metehus/315e5fd46adbc2dc3433d2e5b5728409/raw/28204aa227b80d03a443988652feb06ba12f0729/lbp-train-fold_1.csv\"\n", | |
"lbp_train_fold2_url = \"https://gist.github.com/metehus/315e5fd46adbc2dc3433d2e5b5728409/raw/28204aa227b80d03a443988652feb06ba12f0729/lbp-train-fold_2.csv\"\n", | |
"lbp_train_fold3_url = \"https://gist.github.com/metehus/315e5fd46adbc2dc3433d2e5b5728409/raw/28204aa227b80d03a443988652feb06ba12f0729/lbp-train-fold_3.csv\"\n", | |
"lbp_train_fold4_url = \"https://gist.github.com/metehus/315e5fd46adbc2dc3433d2e5b5728409/raw/28204aa227b80d03a443988652feb06ba12f0729/lbp-train-fold_4.csv\"\n", | |
"\n", | |
"test_df = pd.read_csv(lbp_test_url)\n", | |
"lbp_train_fold0_df = pd.read_csv(lbp_train_fold0_url)\n", | |
"lbp_train_fold1_df = pd.read_csv(lbp_train_fold1_url)\n", | |
"lbp_train_fold2_df = pd.read_csv(lbp_train_fold2_url)\n", | |
"lbp_train_fold3_df = pd.read_csv(lbp_train_fold3_url)\n", | |
"lbp_train_fold4_df = pd.read_csv(lbp_train_fold4_url)\n", | |
"\n", | |
"# Juntar os dfs\n", | |
"df = pd.concat([lbp_train_fold0_df, lbp_train_fold1_df, lbp_train_fold2_df, lbp_train_fold3_df, lbp_train_fold4_df])\n", | |
"\n", | |
"# Preprocessar\n", | |
"df['class'] = df['class'].str.split('/').str[-1]\n", | |
"X = df.drop(columns=['class'])\n", | |
"y = df['class']\n", | |
"\n", | |
"X_test = test_df.drop(columns=['class'])\n", | |
"y_test = test_df['class'].str.split('/').str[-1]\n", | |
"\n", | |
"print(df['class'].value_counts())\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "YBSyXcwwGTeS", | |
"outputId": "eebed61d-4e8e-4a28-c086-6b815e49936b" | |
}, | |
"execution_count": 47, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"class\n", | |
"Normal 700\n", | |
"COVID-19 63\n", | |
"Streptococcus 9\n", | |
"SARS 8\n", | |
"Pneumocystis 8\n", | |
"MERS 7\n", | |
"Varicella 7\n", | |
"Name: count, dtype: int64\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [], | |
"metadata": { | |
"id": "jydi6GMNxNrA" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from imblearn.pipeline import Pipeline\n", | |
"from imblearn.over_sampling import SMOTE\n", | |
"from sklearn.linear_model import LogisticRegression\n", | |
"from sklearn.model_selection import GridSearchCV, StratifiedKFold\n", | |
"from sklearn.preprocessing import StandardScaler\n", | |
"from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score\n", | |
"\n", | |
"smote = SMOTE(sampling_strategy='auto', k_neighbors=3)\n", | |
"\n", | |
"pipeline = Pipeline([\n", | |
" ('scaler', StandardScaler()), # Normalizar dados\n", | |
" ('smote', smote), # Balanceamento\n", | |
" ('model', LogisticRegression(max_iter=1000, C=0.1))\n", | |
"])\n", | |
"\n", | |
"param_grid = {\n", | |
" 'model__C': [0.1, 1, 10],\n", | |
" 'model__max_iter': [1000],\n", | |
" 'model__class_weight': ['balanced'],\n", | |
" 'smote__k_neighbors': [3, 2]\n", | |
"}\n", | |
"\n", | |
"cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)\n", | |
"grid_search = GridSearchCV(pipeline, param_grid, scoring='f1_weighted', cv=cv, n_jobs=-1)\n", | |
"\n", | |
"grid_search.fit(X, y)\n", | |
"print(\"Best parameters found:\", grid_search.best_params_)\n", | |
"\n", | |
"y_pred = grid_search.predict(X_test)\n", | |
"\n", | |
"print(\"* Classificação:\")\n", | |
"print(classification_report(y_test, y_pred, zero_division=1))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "-mVkVPcF011D", | |
"outputId": "214ae154-f9d9-4b5f-cbc1-19c68ad5d516" | |
}, | |
"execution_count": 55, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Best parameters found: {'model__C': 10, 'model__class_weight': 'balanced', 'model__max_iter': 1000, 'smote__k_neighbors': 3}\n", | |
"* Classificação:\n", | |
" precision recall f1-score support\n", | |
"\n", | |
" COVID-19 0.64 0.52 0.57 27\n", | |
" MERS 0.75 1.00 0.86 3\n", | |
" Normal 0.98 0.98 0.98 300\n", | |
" Pneumocystis 0.00 0.00 0.00 3\n", | |
" SARS 1.00 1.00 1.00 3\n", | |
"Streptococcus 0.33 0.33 0.33 3\n", | |
" Varicella 0.12 0.33 0.18 3\n", | |
"\n", | |
" accuracy 0.93 342\n", | |
" macro avg 0.55 0.60 0.56 342\n", | |
" weighted avg 0.93 0.93 0.93 342\n", | |
"\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment