@ninovanhooff
Last active November 20, 2018 14:44
SVM for Fraud Detection
{
"cells": [
{
"metadata": {
"_uuid": "d6d416591c00598f04bfef280548e47217fe9456"
},
"cell_type": "markdown",
"source": "**Fraud Detection with SVM**\n\n"
},
{
"metadata": {
"_uuid": "287b79323023b6a2759534741357a7bb3ad6eff7"
},
"cell_type": "markdown",
"source": "# Finding the principal components that matter for fraud detection, using an SVM and undersampling to correct class imbalance\n\nThis Kernel is based on work by [Davide Vegliante](https://www.kaggle.com/davidevegliante/nn-for-fraud-detection#).\nInstead of a Neural Network, we train an SVM and then use a technique described by [Aneesha Bakharia](https://medium.com/@aneesha/visualising-top-features-in-linear-svm-with-scikit-learn-and-matplotlib-3454ab18a14d) to inspect which of the unnamed features V1..V28 contribute most to the decision boundary / margin of the trained SVM. To increase our confidence that the results are sound, we use [Louis Headley's work](https://www.kaggle.com/louish10/anomaly-detection-for-fraud-detection/notebook) to visualize the probability distributions of various features.\n\ntodo: compare to [Variance Threshold](https://scikit-learn.org/stable/modules/feature_selection.html)\n\n## Dataset\n\nThe dataset contains transactions made with credit cards by European cardholders in September 2013. It covers transactions that occurred over two days, of which 492 out of 284,807 are frauds. The dataset is highly unbalanced: the positive class (frauds) accounts for 0.172% of all transactions.\n"
},
{
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "import numpy as np # linear algebra\nimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\nimport seaborn as sns\nimport matplotlib.pyplot as plt\nimport matplotlib.style\n%matplotlib inline\nimport os\n\nimport warnings \nwarnings.filterwarnings('ignore')\n\n\nmatplotlib.style.use('ggplot')\nfrom cycler import cycler\ncolor_palette = sns.color_palette()\ncolor_palette[0], color_palette[1] = color_palette[1], color_palette[0]\nmatplotlib.rcParams['axes.prop_cycle'] = cycler(color=color_palette)\n\n# read the dataset and print five rows\noriginal_dataset = pd.read_csv('../input/creditcard.csv')\n\ndataset = original_dataset.copy()\nprint(dataset.head(5))\n",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": " Time V1 V2 V3 ... V27 V28 Amount Class\n0 0.0 -1.359807 -0.072781 2.536347 ... 0.133558 -0.021053 149.62 0\n1 0.0 1.191857 0.266151 0.166480 ... -0.008983 0.014724 2.69 0\n2 1.0 -1.358354 -1.340163 1.773209 ... -0.055353 -0.059752 378.66 0\n3 1.0 -0.966272 -0.185226 1.792993 ... 0.062723 0.061458 123.50 0\n4 2.0 -1.158233 0.877737 1.548718 ... 0.219422 0.215153 69.99 0\n\n[5 rows x 31 columns]\n",
"name": "stdout"
}
]
},
{
"metadata": {
"_uuid": "52378368d3e73e4934641a70614520403b21a4ca"
},
"cell_type": "markdown",
"source": "Let's see how many examples of each class our dataset contains."
},
{
"metadata": {
"_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
"_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a",
"trusted": true
},
"cell_type": "code",
"source": "# count how many entries there are for every class\nclasses_count = dataset['Class'].value_counts()\n\nprint(\"{} Bonafide examples\\n{} Fraud examples\".format(classes_count[0], classes_count[1]))\n\n# classes_count is a Series, so it can be plotted directly\nclasses_count.plot(kind = 'bar')\nplt.xlabel('Classes')\nplt.ylabel('Frequencies')\nplt.title('Fraud Class Hist')",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": "284315 Bonafide examples\n492 Fraud examples\n",
"name": "stdout"
},
{
"output_type": "execute_result",
"execution_count": 2,
"data": {
"text/plain": "Text(0.5,1,'Fraud Class Hist')"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"metadata": {
"_uuid": "cc01e42af1b5496883736f07fa289d203ea9e974"
},
"cell_type": "markdown",
"source": "The features `V1`..`V28` seem to be already scaled. The creators of the dataset could not disclose what they represent, but did note that these features are the principal components obtained with [PCA](https://en.wikipedia.org/wiki/Principal_component_analysis). The (transaction) `Amount` and `Time` are not scaled, however. SVM algorithms are not scale invariant, so it is highly recommended to scale these two remaining features.\n\nThe `Time` feature contains the number of seconds elapsed between the start of data recording and the transaction. The dataset covers a timespan of two days. It might be interesting to replace this feature with two new features: `Day` (0 or 1) and `TimeOfDay` (between 0 and 1, where 0 is 00:00 and 1 is 23:59).\n\nTodo: rescale the `Time` feature. For now, it is removed."
},
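{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of the `Day` / `TimeOfDay` split suggested above. It is applied to a few made-up `Time` values rather than to `dataset`. The sketch assumes that `Time` counts seconds from the start of recording and that recording started at 00:00. The dataset description does not say when recording started, so `TimeOfDay` may be shifted by a constant. The helper `split_time` is hypothetical."
},
{
"metadata": {},
"cell_type": "code",
"source": "import pandas as pd\n\nSECONDS_PER_DAY = 24 * 60 * 60  # 86400\n\n\ndef split_time(time_seconds):\n    # Hypothetical helper; assumes recording started at 00:00 on day 0\n    day = (time_seconds // SECONDS_PER_DAY).astype(int)  # 0 or 1 for this two-day dataset\n    time_of_day = (time_seconds % SECONDS_PER_DAY) / SECONDS_PER_DAY  # fraction of the day, in [0, 1)\n    return day, time_of_day\n\n\n# A few made-up Time values (seconds since the start of recording)\nexample = pd.DataFrame({'Time': [0.0, 43200.0, 86399.0, 90000.0]})\nexample['Day'], example['TimeOfDay'] = split_time(example['Time'])\nprint(example)\n",
"execution_count": null,
"outputs": []
},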
{
"metadata": {
"trusted": true,
"_uuid": "76f0cbb25c13d9fce563e985a70df6fd632ed224"
},
"cell_type": "code",
"source": "# Most features are already scaled; Amount still needs scaling (Time is dropped below for now)\nfrom sklearn.preprocessing import RobustScaler\n\n# RobustScaler centers on the median and scales by the interquartile range, so it is less sensitive to outliers\nrob_scaler = RobustScaler()\ndataset['Amount'] = rob_scaler.fit_transform(dataset['Amount'].values.reshape(-1, 1))\n\n# remove the Time feature\ndataset.drop(['Time'], axis = 1, inplace = True)\n\ndataset.head(5)",
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 3,
"data": {
"text/plain": " V1 V2 V3 ... V28 Amount Class\n0 -1.359807 -0.072781 2.536347 ... -0.021053 1.783274 0\n1 1.191857 0.266151 0.166480 ... 0.014724 -0.269825 0\n2 -1.358354 -1.340163 1.773209 ... -0.059752 4.983721 0\n3 -0.966272 -0.185226 1.792993 ... 0.061458 1.418291 0\n4 -1.158233 0.877737 1.548718 ... 0.215153 0.670579 0\n\n[5 rows x 30 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>V1</th>\n <th>V2</th>\n <th>V3</th>\n <th>V4</th>\n <th>V5</th>\n <th>V6</th>\n <th>V7</th>\n <th>V8</th>\n <th>V9</th>\n <th>V10</th>\n <th>V11</th>\n <th>V12</th>\n <th>V13</th>\n <th>V14</th>\n <th>V15</th>\n <th>V16</th>\n <th>V17</th>\n <th>V18</th>\n <th>V19</th>\n <th>V20</th>\n <th>V21</th>\n <th>V22</th>\n <th>V23</th>\n <th>V24</th>\n <th>V25</th>\n <th>V26</th>\n <th>V27</th>\n <th>V28</th>\n <th>Amount</th>\n <th>Class</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>-1.359807</td>\n <td>-0.072781</td>\n <td>2.536347</td>\n <td>1.378155</td>\n <td>-0.338321</td>\n <td>0.462388</td>\n <td>0.239599</td>\n <td>0.098698</td>\n <td>0.363787</td>\n <td>0.090794</td>\n <td>-0.551600</td>\n <td>-0.617801</td>\n <td>-0.991390</td>\n <td>-0.311169</td>\n <td>1.468177</td>\n <td>-0.470401</td>\n <td>0.207971</td>\n <td>0.025791</td>\n <td>0.403993</td>\n <td>0.251412</td>\n <td>-0.018307</td>\n <td>0.277838</td>\n <td>-0.110474</td>\n <td>0.066928</td>\n <td>0.128539</td>\n <td>-0.189115</td>\n <td>0.133558</td>\n <td>-0.021053</td>\n <td>1.783274</td>\n <td>0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1.191857</td>\n <td>0.266151</td>\n <td>0.166480</td>\n <td>0.448154</td>\n <td>0.060018</td>\n <td>-0.082361</td>\n <td>-0.078803</td>\n <td>0.085102</td>\n <td>-0.255425</td>\n <td>-0.166974</td>\n <td>1.612727</td>\n <td>1.065235</td>\n <td>0.489095</td>\n <td>-0.143772</td>\n <td>0.635558</td>\n <td>0.463917</td>\n <td>-0.114805</td>\n <td>-0.183361</td>\n <td>-0.145783</td>\n <td>-0.069083</td>\n <td>-0.225775</td>\n <td>-0.638672</td>\n <td>0.101288</td>\n <td>-0.339846</td>\n <td>0.167170</td>\n <td>0.125895</td>\n 
<td>-0.008983</td>\n <td>0.014724</td>\n <td>-0.269825</td>\n <td>0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>-1.358354</td>\n <td>-1.340163</td>\n <td>1.773209</td>\n <td>0.379780</td>\n <td>-0.503198</td>\n <td>1.800499</td>\n <td>0.791461</td>\n <td>0.247676</td>\n <td>-1.514654</td>\n <td>0.207643</td>\n <td>0.624501</td>\n <td>0.066084</td>\n <td>0.717293</td>\n <td>-0.165946</td>\n <td>2.345865</td>\n <td>-2.890083</td>\n <td>1.109969</td>\n <td>-0.121359</td>\n <td>-2.261857</td>\n <td>0.524980</td>\n <td>0.247998</td>\n <td>0.771679</td>\n <td>0.909412</td>\n <td>-0.689281</td>\n <td>-0.327642</td>\n <td>-0.139097</td>\n <td>-0.055353</td>\n <td>-0.059752</td>\n <td>4.983721</td>\n <td>0</td>\n </tr>\n <tr>\n <th>3</th>\n <td>-0.966272</td>\n <td>-0.185226</td>\n <td>1.792993</td>\n <td>-0.863291</td>\n <td>-0.010309</td>\n <td>1.247203</td>\n <td>0.237609</td>\n <td>0.377436</td>\n <td>-1.387024</td>\n <td>-0.054952</td>\n <td>-0.226487</td>\n <td>0.178228</td>\n <td>0.507757</td>\n <td>-0.287924</td>\n <td>-0.631418</td>\n <td>-1.059647</td>\n <td>-0.684093</td>\n <td>1.965775</td>\n <td>-1.232622</td>\n <td>-0.208038</td>\n <td>-0.108300</td>\n <td>0.005274</td>\n <td>-0.190321</td>\n <td>-1.175575</td>\n <td>0.647376</td>\n <td>-0.221929</td>\n <td>0.062723</td>\n <td>0.061458</td>\n <td>1.418291</td>\n <td>0</td>\n </tr>\n <tr>\n <th>4</th>\n <td>-1.158233</td>\n <td>0.877737</td>\n <td>1.548718</td>\n <td>0.403034</td>\n <td>-0.407193</td>\n <td>0.095921</td>\n <td>0.592941</td>\n <td>-0.270533</td>\n <td>0.817739</td>\n <td>0.753074</td>\n <td>-0.822843</td>\n <td>0.538196</td>\n <td>1.345852</td>\n <td>-1.119670</td>\n <td>0.175121</td>\n <td>-0.451449</td>\n <td>-0.237033</td>\n <td>-0.038195</td>\n <td>0.803487</td>\n <td>0.408542</td>\n <td>-0.009431</td>\n <td>0.798278</td>\n <td>-0.137458</td>\n <td>0.141267</td>\n <td>-0.206010</td>\n <td>0.502292</td>\n <td>0.219422</td>\n <td>0.215153</td>\n <td>0.670579</td>\n <td>0</td>\n </tr>\n 
</tbody>\n</table>\n</div>"
},
"metadata": {}
}
]
},
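{
"metadata": {},
"cell_type": "markdown",
"source": "To make the scaling step concrete: with its default settings, `RobustScaler` subtracts the median and divides by the interquartile range (25th to 75th percentile). The sketch below checks this on a small made-up array that contains one large outlier. The outlier barely moves the median and the interquartile range, whereas it would dominate the mean and standard deviation used by `StandardScaler`."
},
{
"metadata": {},
"cell_type": "code",
"source": "import numpy as np\nfrom sklearn.preprocessing import RobustScaler\n\n# Made-up transaction amounts with one large outlier\namounts = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 1000.0]).reshape(-1, 1)\n\nscaled = RobustScaler().fit_transform(amounts)\n\n# Manual equivalent of the default RobustScaler: (x - median) / IQR\nmedian = np.median(amounts)\nq1, q3 = np.percentile(amounts, [25, 75])\nmanual = (amounts - median) / (q3 - q1)\n\nprint(np.allclose(scaled, manual))\n",
"execution_count": null,
"outputs": []
},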
{
"metadata": {
"_uuid": "4c9da2de181f0f0437433e5c037d1516e7104404"
},
"cell_type": "markdown",
"source": "### Undersampling to a 1:1 class ratio"
},
{
"metadata": {
"trusted": true,
"_uuid": "a03ba047bd2ce7e44ee5d4847960803393cb5c0b"
},
"cell_type": "code",
"source": "X = dataset.loc[:, dataset.columns != 'Class']\ny = dataset['Class']\n\nfrom imblearn.under_sampling import RandomUnderSampler\n# sampling_strategy = 1.0: keep as many bonafide examples as there are frauds\nrus = RandomUnderSampler(random_state = 0, sampling_strategy = 1.0)\n\nX_resampled, y_resampled = rus.fit_resample(X, y)\n\nfrom sklearn.model_selection import train_test_split\nX_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size = 0.20, stratify = y_resampled)\n\n# every fraud example must end up in either the train or the test set\nassert (y_train == 1).sum() + (y_test == 1).sum() == (dataset.Class == 1).sum()\nprint(\"train_set size: {} - Class0: {}, Class1: {}\".format(len(y_train), (y_train == 0).sum(), (y_train == 1).sum()))\nprint(\"test_set size: {} - Class0: {}, Class1: {}\".format(len(y_test), (y_test == 0).sum(), (y_test == 1).sum()))\n",
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": "Using TensorFlow backend.\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "train_set size: 787 - Class0: 393, Class1: 394\ntest_set size: 197 - Class0: 99, Class1: 98\n",
"name": "stdout"
}
]
},
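{
"metadata": {},
"cell_type": "markdown",
"source": "For readers unfamiliar with `imblearn`: random undersampling with `sampling_strategy = 1.0` keeps every minority-class (fraud) example. It then draws, without replacement, an equal number of majority-class examples. A rough pandas equivalent on a made-up frame might look like the sketch below. The helper `undersample` is hypothetical, and its random draws will not match those of `RandomUnderSampler`, even with the same seed."
},
{
"metadata": {},
"cell_type": "code",
"source": "import pandas as pd\n\n\ndef undersample(df, label_col='Class', ratio=1.0, random_state=0):\n    # Hypothetical helper: keep all minority rows and draw\n    # len(minority) / ratio majority rows without replacement\n    counts = df[label_col].value_counts()\n    minority_label, majority_label = counts.idxmin(), counts.idxmax()\n    minority = df[df[label_col] == minority_label]\n    n_majority = int(len(minority) / ratio)\n    majority = df[df[label_col] == majority_label].sample(n=n_majority, random_state=random_state)\n    # concatenate and shuffle the rows\n    return pd.concat([minority, majority]).sample(frac=1, random_state=random_state)\n\n\n# Made-up data: 95 bonafide (0) and 5 fraudulent (1) rows\ntoy = pd.DataFrame({'V1': range(100), 'Class': [0] * 95 + [1] * 5})\nbalanced = undersample(toy)\nprint(balanced['Class'].value_counts())\n",
"execution_count": null,
"outputs": []
},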
{
"metadata": {
"_uuid": "2a12bb26240b3169a8283fe4ae59f3c295e12a55"
},
"cell_type": "markdown",
"source": "### Train a linear SVM\n"
},
{
"metadata": {
"trusted": true,
"_uuid": "5442cb2946880c16b0715b66bc17675ff42344df"
},
"cell_type": "code",
"source": "from sklearn.svm import LinearSVC\n\n# dual=False is recommended when there are more samples than features\nclassifier = LinearSVC(dual=False)\nclassifier.fit(X_train, np.ravel(y_train))",
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 5,
"data": {
"text/plain": "LinearSVC(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n verbose=0)"
},
"metadata": {}
}
]
},
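{
"metadata": {},
"cell_type": "markdown",
"source": "The coefficient inspection below relies on how a linear SVM makes its decision. It computes a weighted sum of the features plus an intercept, `coef_ @ x + intercept_`, and predicts class 1 when this score is positive. The sketch below checks both statements on synthetic data from `make_classification`, so it does not depend on the fraud dataset."
},
{
"metadata": {},
"cell_type": "code",
"source": "import numpy as np\nfrom sklearn.datasets import make_classification\nfrom sklearn.svm import LinearSVC\n\n# Synthetic binary data as a stand-in for the fraud features\nX_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)\ndemo_svm = LinearSVC(dual=False).fit(X_demo, y_demo)\n\n# Decision score: weighted sum of the features plus the intercept\nscores = X_demo @ demo_svm.coef_.ravel() + demo_svm.intercept_[0]\nprint(np.allclose(scores, demo_svm.decision_function(X_demo)))\n\n# Positive scores are predicted as class 1, the others as class 0\nprint(np.array_equal((scores > 0).astype(int), demo_svm.predict(X_demo)))\n",
"execution_count": null,
"outputs": []
},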
{
"metadata": {
"_uuid": "e7ab25df2181b069f419ca8856be73754e39d991"
},
"cell_type": "markdown",
"source": "### Inspect the coefficients"
},
{
"metadata": {
"trusted": true,
"_uuid": "5ff51c723edea0ad7f6160fc5e79528665d8ab7b",
"scrolled": true
},
"cell_type": "code",
"source": "def plot_coefficients(classifier, feature_names, top_features=-1):\n    if top_features == -1:\n        top_features = len(feature_names)\n\n    coef = classifier.coef_.ravel()\n    abs_coef = np.abs(coef)\n    # indices of the top_features largest |coefficients|, largest first\n    top_coefficients = np.argsort(-abs_coef)[:top_features]\n\n    # create plot\n    plt.clf()\n    plt.figure(figsize=(15, 3))\n    # colour by sign: color_palette[1] for positive, color_palette[0] otherwise\n    colors = [color_palette[int(c > 0)] for c in coef[top_coefficients]]\n    plt.bar(np.arange(top_features), coef[top_coefficients], color=colors)\n    feature_names = np.array(feature_names)\n    plt.xticks(np.arange(0, top_features), feature_names[top_coefficients], rotation=60, ha='right')\n    plt.title(\"Feature coefficients\")\n    plt.ylabel(\"Coefficient\")\n\nfeature_names = list(X.columns.values)\nplot_coefficients(classifier, feature_names)\nplt.show()",
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 0 Axes>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 1080x216 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
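{
"metadata": {},
"cell_type": "markdown",
"source": "The same ranking can also be read as a table. This sketch pairs each coefficient with its feature name and sorts by absolute value, so the most influential features come first regardless of sign. The helper `rank_coefficients` is hypothetical and is demonstrated on made-up coefficients. With the model above, you could pass `classifier.coef_` and `feature_names` instead."
},
{
"metadata": {},
"cell_type": "code",
"source": "import numpy as np\nimport pandas as pd\n\n\ndef rank_coefficients(coef, names):\n    # Hypothetical helper: one row per feature, largest |coefficient| first\n    table = pd.DataFrame({'feature': names, 'coefficient': np.ravel(coef)})\n    table['abs_coefficient'] = table['coefficient'].abs()\n    return table.sort_values('abs_coefficient', ascending=False).reset_index(drop=True)\n\n\n# Made-up coefficients for illustration only\nprint(rank_coefficients([0.3, -1.2, 0.05], ['f1', 'f2', 'f3']))\n",
"execution_count": null,
"outputs": []
},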
{
"metadata": {
"trusted": true,
"_uuid": "29bda89500394769ddbe6bffa99da325df8e2320"
},
"cell_type": "code",
"source": "dataset.groupby(\"Class\")[['V14', 'V15', 'V19']].describe(percentiles=[])",
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 7,
"data": {
"text/plain": " V14 ... V19 \n count mean std ... min 50% max\nClass ... \n0 284315.0 0.012064 0.897007 ... -7.213527 0.003117 5.591971\n1 492.0 -6.971723 4.278940 ... -3.681904 0.646807 5.228342\n\n[2 rows x 18 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead tr th {\n text-align: left;\n }\n\n .dataframe thead tr:last-of-type th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr>\n <th></th>\n <th colspan=\"6\" halign=\"left\">V14</th>\n <th colspan=\"6\" halign=\"left\">V15</th>\n <th colspan=\"6\" halign=\"left\">V19</th>\n </tr>\n <tr>\n <th></th>\n <th>count</th>\n <th>mean</th>\n <th>std</th>\n <th>min</th>\n <th>50%</th>\n <th>max</th>\n <th>count</th>\n <th>mean</th>\n <th>std</th>\n <th>min</th>\n <th>50%</th>\n <th>max</th>\n <th>count</th>\n <th>mean</th>\n <th>std</th>\n <th>min</th>\n <th>50%</th>\n <th>max</th>\n </tr>\n <tr>\n <th>Class</th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>284315.0</td>\n <td>0.012064</td>\n <td>0.897007</td>\n <td>-18.392091</td>\n <td>0.051947</td>\n <td>10.526766</td>\n <td>284315.0</td>\n <td>0.000161</td>\n <td>0.915060</td>\n <td>-4.391307</td>\n <td>0.048294</td>\n <td>8.877742</td>\n <td>284315.0</td>\n <td>-0.001178</td>\n <td>0.811733</td>\n <td>-7.213527</td>\n <td>0.003117</td>\n <td>5.591971</td>\n </tr>\n <tr>\n <th>1</th>\n <td>492.0</td>\n <td>-6.971723</td>\n <td>4.278940</td>\n <td>-19.214325</td>\n <td>-6.729720</td>\n <td>3.442422</td>\n <td>492.0</td>\n <td>-0.092929</td>\n <td>1.049915</td>\n <td>-4.498945</td>\n <td>-0.057227</td>\n <td>2.471358</td>\n <td>492.0</td>\n <td>0.680659</td>\n <td>1.539853</td>\n <td>-3.681904</td>\n <td>0.646807</td>\n <td>5.228342</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true,
"_uuid": "aa6c3500c1799457d9b201f7b7903cbca69e3a50"
},
"cell_type": "code",
"source": "for feature in ['V14', 'V15', 'V19']:\n    ax = plt.subplot()\n    positive = dataset[feature][dataset.Class == 1]\n    negative = dataset[feature][dataset.Class == 0]\n\n    # overlay the per-class distributions of this feature\n    sns.distplot(positive, bins=50, label='Fraudulent')\n    sns.distplot(negative, bins=50, label='Bonafide')\n    ax.set_xlabel('')\n    ax.set_title('histogram of feature: ' + str(feature))\n    plt.legend(loc='best')\n    plt.show()",
"execution_count": 8,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"metadata": {
"_uuid": "b9356fb87f8fa3fd43929c2432499f0319df78ec"
},
"cell_type": "markdown",
"source": "On first inspection, the magnitude of a coefficient does seem to be a good indicator of the importance of a feature. The sign of a coefficient gives the direction. Since `Class` 1 means fraudulent, a positive coefficient means that higher values of the feature push the prediction towards fraud, while a negative coefficient means that lower values do. The group statistics above show that fraudulent transactions tend to have much lower `V14` values than bonafide ones (mean -6.97 versus 0.01) and somewhat higher `V19` values (mean 0.68 versus roughly 0). However, the coefficient for `V19` is much weaker than the coefficient for `V14`, so we expect `V19` to have much less influence on our prediction. We also expect the probability distributions of fraudulent and bonafide transactions to overlap more for `V19` than for `V14`. The coefficient of `V15` is even weaker than that of `V19`, and indeed the fraudulent and bonafide distributions overlap more for `V15` than for `V19`. The predictions about relative importance were made from the trained SVM coefficients alone, and all of them are in line with the probability distributions shown above."
},
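{
"metadata": {},
"cell_type": "markdown",
"source": "The overlap described above can also be expressed as a number. One common measure, not used in the original analysis, is the overlapping coefficient: the area shared by two normalized histograms. It is 1 for identical distributions and close to 0 for well-separated ones. The sketch below uses made-up normal samples; the function name and the bin count of 50 are our own choices. Computing this measure for `V14`, `V15` and `V19` would give a numeric check of the ordering described above."
},
{
"metadata": {},
"cell_type": "code",
"source": "import numpy as np\n\n\ndef overlap_coefficient(a, b, bins=50):\n    # Shared area of two density histograms computed on common bin edges\n    edges = np.histogram_bin_edges(np.concatenate([a, b]), bins=bins)\n    pa, _ = np.histogram(a, bins=edges, density=True)\n    pb, _ = np.histogram(b, bins=edges, density=True)\n    return float(np.sum(np.minimum(pa, pb) * np.diff(edges)))\n\n\nrng = np.random.default_rng(0)\nreference = rng.normal(0.0, 1.0, 10000)\nnearby = rng.normal(0.5, 1.0, 10000)  # heavily overlapping\nfar_away = rng.normal(5.0, 1.0, 10000)  # barely overlapping\nprint(overlap_coefficient(reference, nearby))\nprint(overlap_coefficient(reference, far_away))\n",
"execution_count": null,
"outputs": []
},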
{
"metadata": {
"_uuid": "217fdf976652693f375da858156d2d87c902a4c8"
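},
"cell_type": "markdown",
"source": "## Report the classifier performance\n\nNote that **accuracy alone is misleading on an unbalanced dataset**; this is one reason we used undersampling. Keep in mind, however, that the test set is a split of the undersampled (balanced) data. Precision depends on the class ratio, so the precision reported below applies to a 1:1 ratio. It is expected to be much lower on the original data, where only 0.172% of transactions are fraudulent."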
},
{
"metadata": {
"trusted": true,
"_uuid": "0ca7e49ef2578760b4db7ad3abe3f7a54a3d22db"
},
"cell_type": "code",
"source": "from sklearn.metrics import classification_report\n# predict() already returns class labels (0 or 1), so no threshold is needed\ny_test_pred = classifier.predict(X_test)\ntarget_names = [\"Bonafide\", \"Fraudulent\"]\nprint(classification_report(y_test, y_test_pred, target_names=target_names))",
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": " precision recall f1-score support\n\n Bonafide 0.88 0.93 0.91 99\n Fraudulent 0.92 0.88 0.90 98\n\n micro avg 0.90 0.90 0.90 197\n macro avg 0.90 0.90 0.90 197\nweighted avg 0.90 0.90 0.90 197\n\n",
"name": "stdout"
}
]
},
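{
"metadata": {},
"cell_type": "markdown",
"source": "The report above is computed on a balanced test split. Here is a back-of-the-envelope sketch of why the class ratio matters. Precision is TP / (TP + FP). Suppose the per-class recalls from the report (0.88 for fraudulent, 0.93 for bonafide) carried over unchanged to the full dataset. The false positives would then scale with the 284,315 bonafide transactions, while the true positives scale with only 492 frauds. That assumption is ours, and the rounded recalls make this an approximation, not a measured result. The helper `expected_precision` is hypothetical."
},
{
"metadata": {},
"cell_type": "code",
"source": "def expected_precision(fraud_recall, bonafide_recall, n_fraud, n_bonafide):\n    # Hypothetical helper; assumes the per-class recalls carry over unchanged\n    true_positives = fraud_recall * n_fraud\n    false_positives = (1.0 - bonafide_recall) * n_bonafide\n    return true_positives / (true_positives + false_positives)\n\n\n# Balanced test split above: 98 fraudulent and 99 bonafide examples\nprint(round(expected_precision(0.88, 0.93, 98, 99), 2))\n# Same recalls applied to the full dataset: 492 fraudulent, 284315 bonafide\nprint(round(expected_precision(0.88, 0.93, 492, 284315), 3))\n",
"execution_count": null,
"outputs": []
},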
{
"metadata": {
"_uuid": "7f7165f5eabee87292413968be3cac1f746435d8"
},
"cell_type": "markdown",
"source": "### Confusion Matrix"
},
{
"metadata": {
"trusted": true,
"_uuid": "ea6a94b895724239172956e405e298224dbd4422"
},
"cell_type": "code",
"source": "\nfrom sklearn.metrics import confusion_matrix\ncm = confusion_matrix(y_test, y_test_pred)\n\nplt.clf()\nplt.grid(False)\nplt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\nclassNames = target_names\nplt.title('Fraud or Not Fraud Confusion Matrix - Test Data')\nplt.ylabel('True label')\nplt.xlabel('Predicted label')\ntick_marks = np.arange(len(classNames))\nplt.xticks(tick_marks, classNames, rotation=45)\nplt.yticks(tick_marks, classNames)\n# rows are true labels, columns are predicted labels\ns = [['TN','FP'], ['FN', 'TP']]\nthresh = cm.max() / 2.\nfor i in range(2):\n    for j in range(2):\n        plt.text(j, i, str(s[i][j]) + \" = \" + str(cm[i][j]), horizontalalignment=\"center\",\n                 color=\"white\" if cm[i, j] > thresh else \"black\")\nplt.show()\n",
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.6.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
}
},
"nbformat": 4,
"nbformat_minor": 1
}