Created
May 10, 2022 05:11
-
-
Save XinyueZ/2a2106c06fcd4f628d01c38afa6f21b4 to your computer and use it in GitHub Desktop.
Bernoulli Event Model.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Bernoulli Event Model.ipynb", | |
"provenance": [], | |
"machine_shape": "hm", | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"authorship_tag": "ABX9TyNtzqPdQwIqHBtsrdolbfNi", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/XinyueZ/2a2106c06fcd4f628d01c38afa6f21b4/bernoulli-event-model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "eqhmG1gt0ZTA" | |
}, | |
"outputs": [], | |
"source": [ | |
"import csv\n", | |
"import tensorflow as tf\n", | |
"from tensorflow.keras.preprocessing.text import Tokenizer\n", | |
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n", | |
"import sys\n", | |
"from decimal import *\n", | |
"from IPython.display import clear_output" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Tools" | |
], | |
"metadata": { | |
"id": "QUNrACM0lkaU" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def process_in_duration(message, proc):\n", | |
" import datetime\n", | |
" \n", | |
" tm_start = datetime.datetime.now()\n", | |
"\n", | |
" result = proc()\n", | |
"\n", | |
" tm_end = datetime.datetime.now() \n", | |
" delta = tm_end - tm_start\n", | |
" \n", | |
" print(f\"\\033[92m{message} used {delta}\")\n", | |
" return result" | |
], | |
"metadata": { | |
"id": "dWF25vkZlncR" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Load data" | |
], | |
"metadata": { | |
"id": "mca0rkXC1f2A" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Download" | |
], | |
"metadata": { | |
"id": "nTidvrcH1hDP" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!wget --no-check-certificate \\\n", | |
" https://dl.dropbox.com/s/igq20e1nvwjwxw4/spam_ham_dataset.csv \\\n", | |
" -O ./spam_ham_dataset.csv" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "SGsaEuvt0689", | |
"outputId": "e138da75-172b-412f-b6dc-33397eeca062" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"--2022-01-10 07:11:05-- https://dl.dropbox.com/s/igq20e1nvwjwxw4/spam_ham_dataset.csv\n", | |
"Resolving dl.dropbox.com (dl.dropbox.com)... 162.125.2.15, 2620:100:6016:15::a27d:10f\n", | |
"Connecting to dl.dropbox.com (dl.dropbox.com)|162.125.2.15|:443... connected.\n", | |
"HTTP request sent, awaiting response... 302 Found\n", | |
"Location: https://dl.dropboxusercontent.com/s/igq20e1nvwjwxw4/spam_ham_dataset.csv [following]\n", | |
"--2022-01-10 07:11:05-- https://dl.dropboxusercontent.com/s/igq20e1nvwjwxw4/spam_ham_dataset.csv\n", | |
"Resolving dl.dropboxusercontent.com (dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f\n", | |
"Connecting to dl.dropboxusercontent.com (dl.dropboxusercontent.com)|162.125.1.15|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 5502589 (5.2M) [text/csv]\n", | |
"Saving to: ‘./spam_ham_dataset.csv’\n", | |
"\n", | |
"./spam_ham_dataset. 100%[===================>] 5.25M 32.8MB/s in 0.2s \n", | |
"\n", | |
"2022-01-10 07:11:06 (32.8 MB/s) - ‘./spam_ham_dataset.csv’ saved [5502589/5502589]\n", | |
"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Build dataset" | |
], | |
"metadata": { | |
"id": "5jfzdP3L1jAF" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sentences = []\n", | |
"labels = []\n", | |
"# Some general stopwords, we dont consider them as part of data.\n", | |
"stopwords = [ \"a\", \"about\", \"above\", \"after\", \"again\", \"against\", \"all\", \"am\", \"an\", \"and\", \"any\", \"are\", \"as\", \"at\", \"be\", \"because\", \"been\", \"before\", \"being\", \"below\", \"between\", \"both\", \"but\", \"by\", \"could\", \"did\", \"do\", \"does\", \"doing\", \"down\", \"during\", \"each\", \"few\", \"for\", \"from\", \"further\", \"had\", \"has\", \"have\", \"having\", \"he\", \"he'd\", \"he'll\", \"he's\", \"her\", \"here\", \"here's\", \"hers\", \"herself\", \"him\", \"himself\", \"his\", \"how\", \"how's\", \"i\", \"i'd\", \"i'll\", \"i'm\", \"i've\", \"if\", \"in\", \"into\", \"is\", \"it\", \"it's\", \"its\", \"itself\", \"let's\", \"me\", \"more\", \"most\", \"my\", \"myself\", \"nor\", \"of\", \"on\", \"once\", \"only\", \"or\", \"other\", \"ought\", \"our\", \"ours\", \"ourselves\", \"out\", \"over\", \"own\", \"same\", \"she\", \"she'd\", \"she'll\", \"she's\", \"should\", \"so\", \"some\", \"such\", \"than\", \"that\", \"that's\", \"the\", \"their\", \"theirs\", \"them\", \"themselves\", \"then\", \"there\", \"there's\", \"these\", \"they\", \"they'd\", \"they'll\", \"they're\", \"they've\", \"this\", \"those\", \"through\", \"to\", \"too\", \"under\", \"until\", \"up\", \"very\", \"was\", \"we\", \"we'd\", \"we'll\", \"we're\", \"we've\", \"were\", \"what\", \"what's\", \"when\", \"when's\", \"where\", \"where's\", \"which\", \"while\", \"who\", \"who's\", \"whom\", \"why\", \"why's\", \"with\", \"would\", \"you\", \"you'd\", \"you'll\", \"you're\", \"you've\", \"your\", \"yours\", \"yourself\", \"yourselves\" ]\n", | |
"avoidwords = [\"Subject:\"]\n", | |
"\n", | |
"with open(\"./spam_ham_dataset.csv\", 'r') as csvfile:\n", | |
" reader = csv.reader(csvfile, delimiter=',')\n", | |
" next(reader)\n", | |
" for row in reader:\n", | |
" labels.append(row[-1])\n", | |
" sentence = row[2]\n", | |
" for word in stopwords:\n", | |
" token = \" \" + word + \" \"\n", | |
" sentence = sentence.replace(token, \" \")\n", | |
" for avoid in avoidwords: \n", | |
" sentence = sentence.replace(avoid, \"\")\n", | |
" sentences.append(sentence)" | |
], | |
"metadata": { | |
"id": "PneociVX1Ghg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(len(labels))\n", | |
"print(len(sentences))\n", | |
"print(sentences[0])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "i-gly-dS1OXb", | |
"outputId": "8b7e908c-0d1a-46ef-c470-571f358b1cfd" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"5171\n", | |
"5171\n", | |
" enron methanol ; meter # : 988291\n", | |
"this follow note gave monday , 4 / 3 / 00 { preliminary\n", | |
"flow data provided daren } .\n", | |
"please override pop ' s daily volume { presently zero } reflect daily\n", | |
"activity can obtain gas control .\n", | |
"this change needed asap economics purposes .\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Model" | |
], | |
"metadata": { | |
"id": "Btrj-z0n45BC" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class BernoulliEventModel:\n", | |
" def __init__(self):\n", | |
" self.parameters = {}\n", | |
"\n", | |
" self.mount_y_1 = 0 # Total lengthes of all input with label 1\n", | |
" self.mount_y_0 = 0 # Total lengthes of all input with label 0\n", | |
"\n", | |
" self.is_spam = \"1\"\n", | |
" self.not_spam = \"0\"\n", | |
"\n", | |
" self.score_list = []\n", | |
"\n", | |
" def preprocess(self, sentences, labels):\n", | |
" self.tokenizer = Tokenizer()\n", | |
" self.tokenizer.fit_on_texts(sentences) \n", | |
" self.tokenizer.texts_to_sequences(sentences)\n", | |
"\n", | |
" # Split training and cross-validation data\n", | |
" X = self.tokenizer.texts_to_sequences(sentences)\n", | |
" Y = labels\n", | |
"\n", | |
" m = len(Y)\n", | |
" sp = int(m * (1-0.01))\n", | |
" X_train = X[:sp]\n", | |
" Y_train = Y[:sp]\n", | |
" X_cv = X[sp:]\n", | |
" Y_cv = Y[sp:]\n", | |
"\n", | |
" # Laplace smoothing\n", | |
" self.U = 1\n", | |
" self.V = len(self.tokenizer.index_word)\n", | |
"\n", | |
" self.mount_y_1 = Y_train.count(self.is_spam)\n", | |
" self.mount_y_0 = Y_train.count(self.not_spam)\n", | |
"\n", | |
" return X_train, Y_train, X_cv, Y_cv\n", | |
"\n", | |
" def fit(self, X, Y, X_cv, Y_cv):\n", | |
"\n", | |
" def zeros(m,n): return [[0 for col in range(n)] for row in range(m)]\n", | |
"\n", | |
" m = len(Y)\n", | |
" \n", | |
" self.parameters[\"∅_y_1\"] = Y.count(self.is_spam) / m\n", | |
" self.parameters[\"∅_y_0\"] = Y.count(self.not_spam) / m\n", | |
"\n", | |
" *tokenizer_indices, = self.tokenizer.index_word\n", | |
" n = len(tokenizer_indices)\n", | |
" presence_table = zeros(m, n+1)\n", | |
"\n", | |
" # Initilize parameters of tokenizer\n", | |
" for tokenizer_index in tokenizer_indices: \n", | |
" self.parameters[f\"∅_{tokenizer_index}_y_1\"] = 0\n", | |
" self.parameters[f\"∅_{tokenizer_index}_y_0\"] = 0\n", | |
"\n", | |
" for (row_index, _) in enumerate(Y):\n", | |
" x = X[row_index]\n", | |
" for tokenizer_index in x: \n", | |
" presence_table[row_index][tokenizer_index] = 1\n", | |
" \n", | |
" for tokenizer_index in tokenizer_indices: \n", | |
" sys.stdout.write(f\"Tokenizer: {tokenizer_index} | \")\n", | |
" for (row_index, y) in enumerate(Y):\n", | |
" if y == self.is_spam: \n", | |
" self.parameters[f\"∅_{tokenizer_index}_y_1\"] += presence_table[row_index][tokenizer_index]\n", | |
" else: \n", | |
" self.parameters[f\"∅_{tokenizer_index}_y_0\"] += presence_table[row_index][tokenizer_index]\n", | |
" self.parameters[f\"∅_{tokenizer_index}_y_1\"] = (self.parameters[f\"∅_{tokenizer_index}_y_1\"] + self.U) / (self.mount_y_1 + self.V)\n", | |
" self.parameters[f\"∅_{tokenizer_index}_y_0\"] = (self.parameters[f\"∅_{tokenizer_index}_y_0\"] + self.U) / (self.mount_y_0 + self.V)\n", | |
" \n", | |
" _, preds_cv = self.predict(X_cv) \n", | |
" score = self.validate(preds_cv, Y_cv)\n", | |
" sys.stdout.write(f\"validate(F-Score): {score}\")\n", | |
"\n", | |
" sys.stdout.flush()\n", | |
" clear_output(wait=True)\n", | |
"\n", | |
" return self.parameters\n", | |
"\n", | |
" def validate(self, pred_list, true_list):\n", | |
" score = self.F_Score(pred_list, true_list)\n", | |
" self.score_list.append(score)\n", | |
" return score\n", | |
"\n", | |
" def validation_score(self):\n", | |
" return self.score_list[-1]\n", | |
"\n", | |
" def F_Score(self, pred_list, true_list): \n", | |
" def bitwise(a, b, core): \n", | |
" return [int(core(pred, true)) for pred, true in zip(pred_list, true_list)]\n", | |
"\n", | |
" true_positive = bitwise(pred_list, true_list, lambda pred, true: pred==1 and true==\"1\") \n", | |
" true_positive = sum(true_positive)\n", | |
" \n", | |
" false_positive = bitwise(pred_list, true_list, lambda pred, true: pred==1 and true==\"0\") \n", | |
" false_positive = sum(false_positive)\n", | |
" \n", | |
" false_negative = bitwise(pred_list, true_list, lambda pred, true: pred==0 and true==\"1\") \n", | |
" false_negative = sum(false_negative)\n", | |
" \n", | |
" score = 0\n", | |
" if (true_positive+false_positive) * (true_positive+false_negative) > 0: \n", | |
" precision = true_positive/(true_positive+false_positive);\n", | |
" recall = true_positive/(true_positive+false_negative);\n", | |
" if (precision+recall) > 0:\n", | |
" score = (2*precision*recall)/(precision+recall);\n", | |
"\n", | |
" return score \n", | |
"\n", | |
" def predict(self, X):\n", | |
" Y_1 = self.parameters[\"∅_y_1\"] \n", | |
" Y_0 = self.parameters[\"∅_y_0\"] \n", | |
"\n", | |
" results = []\n", | |
" for x in X:\n", | |
" uniqued_x = list(set(x)) # Remove duplicated occur.\n", | |
"\n", | |
" A = Decimal(1)\n", | |
" B = Decimal(1)\n", | |
"\n", | |
" for tokenizer_index in uniqued_x:\n", | |
" occur_count = x.count(tokenizer_index)\n", | |
" \n", | |
" if self.parameters[f\"∅_{tokenizer_index}_y_1\"] > 0:\n", | |
" A *= Decimal(Decimal(self.parameters[f\"∅_{tokenizer_index}_y_1\"])**occur_count)\n", | |
"\n", | |
" if self.parameters[f\"∅_{tokenizer_index}_y_0\"] > 0:\n", | |
" B *= Decimal(Decimal(self.parameters[f\"∅_{tokenizer_index}_y_0\"])**occur_count)\n", | |
" \n", | |
" A = Decimal(Decimal(A) * Decimal(Y_1)) \n", | |
" B = Decimal(Decimal(B) * Decimal(Y_0))\n", | |
" Bayes = A / (A + B)\n", | |
" results.append(Bayes) \n", | |
" return results, [int(x>0.5) for x in results]" | |
], | |
"metadata": { | |
"id": "j84SgthofoC9" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Training model" | |
], | |
"metadata": { | |
"id": "gZHYV2-ei2cw" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = BernoulliEventModel()\n", | |
"X_train, Y_train, X_cv, Y_cv = model.preprocess(sentences, labels)\n", | |
"parameters = process_in_duration(\"Training completed\", lambda: model.fit(X_train[:], Y_train[:],\n", | |
" X_cv, Y_cv))" | |
], | |
"metadata": { | |
"id": "Zah-s1xb47Rh", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "5b4ccafe-dbcc-4964-ac32-144c844f227b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[92mTraining completed used 0:54:48.870572\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"##### Final score after training with cross-validation" | |
], | |
"metadata": { | |
"id": "YkwL0ACERynE" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model.validation_score()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "utYQjqz4R2j1", | |
"outputId": "5521ecce-ee1a-42f8-a249-c407ba959ef9" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.9032258064516129" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Prediction" | |
], | |
"metadata": { | |
"id": "PlmNRl19S9Sw" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"##### Test data" | |
], | |
"metadata": { | |
"id": "P_OKNV0Ri07b" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X_Test = model.tokenizer.texts_to_sequences(\n", | |
" [\n", | |
" \"photoshop , windows , office, office . office. cheap . main trending abasements darer prudently fortuitous Zhao Xinyue\",\n", | |
" \"underpriced issue with high return on equity stock report . dont sieep on this stock ! th...\",\n", | |
" \"jennifer sends them to their final destination . designated as a private key 4 . validat...\",\n", | |
" \"software microsoft windows xp professioznal 2002 retail price : $ 270 . 99 our low pricie\",\n", | |
" \"neon retreat ho ho ho , we ' re around to that most wonderful time of the year\",\n", | |
" \"noms / actual flow for 2 / 26 we agree - - - - - - - - - - - - - - - - - - - - - - forwar...\"\n", | |
" ]\n", | |
") \n", | |
"\n", | |
"_, preds = model.predict(X_Test)\n", | |
"preds" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "X5gMcN5YWdU9", | |
"outputId": "2355ca9b-4b53-4f97-c6bc-e5e8c370b14b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[1, 0, 0, 1, 0, 0]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 9 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment