Last active
August 16, 2020 10:49
-
-
Save LowriWilliams/6331805a3eb9eebc5ef90c8fd19be39f to your computer and use it in GitHub Desktop.
sms_adversarial/original_classification
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "headers = list(word2vec_df)\n", | |
| "headers.remove('label_not_spam')\n", | |
| "headers.remove('label_spam')\n", | |
| "\n", | |
| "X = np.array(word2vec_df[headers].values.tolist())\n", | |
| "y = np.array(word2vec_df[['label_not_spam', 'label_spam']].values.tolist())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Value counts for training \n", | |
| "\n", | |
| "3900\n", | |
| "\n", | |
| "\n", | |
| "Value counts for testing \n", | |
| "\n", | |
| "1672\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Split the data into training (70%) and testing (30%) sets\n", | |
| "\n", | |
| "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n", | |
| "\n", | |
| "print(\"Value counts for training \\n\")\n", | |
| "print(y_train[:, 0].size)\n", | |
| "print(\"\\n\")\n", | |
| "print(\"Value counts for testing \\n\")\n", | |
| "print(y_test[:, 0].size)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| " precision recall f1-score support\n", | |
| "\n", | |
| " 0 0.97 0.96 0.97 1462\n", | |
| " 1 0.75 0.82 0.78 210\n", | |
| "\n", | |
| " micro avg 0.94 0.94 0.94 1672\n", | |
| " macro avg 0.86 0.89 0.88 1672\n", | |
| "weighted avg 0.95 0.94 0.94 1672\n", | |
| " samples avg 0.94 0.94 0.94 1672\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def classify(X_train, y_train, X_test, y_test):\n", | |
| " # Initialise Decision Tree\n", | |
| " clf = DecisionTreeClassifier()\n", | |
| " # Fit model\n", | |
| " model = clf.fit(X_train, y_train)\n", | |
| " # Predict target labels for the test set\n", | |
| " prediction = model.predict(X_test)\n", | |
| " \n", | |
| " return prediction\n", | |
| "\n", | |
| "print(classification_report(y_test, classify(X_train, y_train, X_test, y_test)))" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment