Last active
August 30, 2021 06:22
-
-
Save shadiakiki1986/689980135fe9dde1d892127bde40a5a1 to your computer and use it in GitHub Desktop.
SVM-RBF sensitivity to translation.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "SVM-RBF sensitivity to translation.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"authorship_tag": "ABX9TyOF2Z+MPhxOHaV1DPw1vbgU", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shadiakiki1986/689980135fe9dde1d892127bde40a5a1/svm-rbf-sensitivity-to-translation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "8GaAZFCMWaw4" | |
}, | |
"source": [ | |
"Testing if RBF is invariant to translation:\n", | |
"\n", | |
"- sklearn issue: https://github.com/scikit-learn/scikit-learn/issues/18432\n", | |
"- gist: https://gist.github.com/xtomasch/84d1d8574ef51eb8d42e77560d647e06\n", | |
"\n", | |
"\n", | |
"Uses my digits dataset with jitter: https://github.com/shadiakiki1986/mnist-digits-jitter\n", | |
"\n", | |
"Published as gist at https://gist.github.com/shadiakiki1986/689980135fe9dde1d892127bde40a5a1" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "HGkCzTmwXS4o" | |
}, | |
"source": [ | |
"# get data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wxWb-jzbXTdv" | |
}, | |
"source": [ | |
"# first the original data\n", | |
"from sklearn.datasets import load_digits\n", | |
"digits = load_digits()" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "RrAkpWbdfPMu", | |
"outputId": "0ac4e1d5-748f-407d-95ce-622f09c8e8db" | |
}, | |
"source": [ | |
"# then the padded and jittered data\n", | |
"!git clone https://github.com/shadiakiki1986/mnist-digits-jitter\n", | |
"\n", | |
"# Update: no need to gunzip since np.loadtxt can automatically do it\n", | |
"#!gunzip mnist-digits-jitter/digits_padded.csv.gz\n", | |
"#!gunzip mnist-digits-jitter/digits_jitter.csv.gz" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Cloning into 'mnist-digits-jitter'...\n", | |
"remote: Enumerating objects: 47, done.\u001b[K\n", | |
"remote: Counting objects: 100% (47/47), done.\u001b[K\n", | |
"remote: Compressing objects: 100% (44/44), done.\u001b[K\n", | |
"remote: Total 47 (delta 20), reused 8 (delta 1), pack-reused 0\u001b[K\n", | |
"Unpacking objects: 100% (47/47), done.\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "66oFvhVEfZMk", | |
"outputId": "db7a154b-a160-47c1-ee81-502f210d5e0b" | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"# np.loadtxt can decompress the files on read\n", | |
"digpad = {\"data\": np.loadtxt(\"mnist-digits-jitter/digits_padded.csv.gz\", delimiter=\",\", dtype=int)}\n", | |
"digjit = {\"data\": np.loadtxt(\"mnist-digits-jitter/digits_jitter.csv.gz\", delimiter=\",\", dtype=int)}\n", | |
"digpad[\"data\"].shape, digjit[\"data\"].shape" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"((1797, 225), (1797, 225))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nBjvLW_9f3YM" | |
}, | |
"source": [ | |
"# convert data to image (not flat)\n", | |
"#def im2data(digxxx_img):\n", | |
"# return np.vstack([img.reshape((-1,1)).squeeze() for img in digxxx_img])\n", | |
"\n", | |
"def data2im(digxxx_data):\n", | |
" s = int(digxxx_data.shape[1]**.5)\n", | |
" l = digxxx_data.reshape((-1,s,s))\n", | |
" return l\n", | |
"\n", | |
"digjit[\"images\"] = data2im(digjit[\"data\"])\n", | |
"digpad[\"images\"] = data2im(digpad[\"data\"])\n", | |
"\n", | |
"assert digjit[\"data\"].shape == (1797, 225)\n", | |
"assert digpad[\"data\"].shape == (1797, 225)\n", | |
"assert digjit[\"images\"].shape == (1797, 15, 15)\n", | |
"assert digpad[\"images\"].shape == (1797, 15, 15)" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "N8Ukyc9Twqve" | |
}, | |
"source": [ | |
"# run svm and knn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "VIT0GtY0ww5a", | |
"outputId": "2630bfe0-c5d1-4bfb-f92d-4d5094d7dfbb" | |
}, | |
"source": [ | |
"from sklearn.neighbors import KNeighborsClassifier\n", | |
"from sklearn import metrics, svm, model_selection\n", | |
"import statistics\n", | |
"\n", | |
"clf_l = [\n", | |
" # n neighbors = 3 is better than any of 1,2,4,5\n", | |
" (\"KNN\", KNeighborsClassifier(n_neighbors=3)),\n", | |
" # default is kernel=rbf\n", | |
" # Gamme 1e-3 is better than any of (1e-5,1e-4,1e-2,1e-1)\n", | |
" # Try linear, poly, RBF as xtomasch gist\n", | |
" # https://gist.github.com/xtomasch/84d1d8574ef51eb8d42e77560d647e06\n", | |
" (\"SVM linear\", svm.SVC(kernel=\"linear\")),\n", | |
" (\"SVM poly\", svm.SVC(kernel=\"poly\")),\n", | |
" (\"SVM RBF\", svm.SVC(kernel=\"rbf\", gamma=0.001)),\n", | |
"]\n", | |
"\n", | |
"X_l = [\n", | |
" (\"no jitter\", digpad[\"data\"]),\n", | |
" (\"with jitter\", digjit[\"data\"]),\n", | |
" ]\n", | |
"\n", | |
"\n", | |
"for X_name, X_i in X_l:\n", | |
" for clf_name, clf_i in clf_l:\n", | |
" results = model_selection.cross_val_score(clf_i, X_i, digits.target)\n", | |
" print(f\"{clf_name}, {X_name}: {statistics.mean(results).round(2)}\")" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"KNN, no jitter: 0.97\n", | |
"SVM linear, no jitter: 0.95\n", | |
"SVM poly, no jitter: 0.96\n", | |
"SVM RBF, no jitter: 0.97\n", | |
"KNN, with jitter: 0.35\n", | |
"SVM linear, with jitter: 0.1\n", | |
"SVM poly, with jitter: 0.27\n", | |
"SVM RBF, with jitter: 0.32\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment