Skip to content

Instantly share code, notes, and snippets.

@karlnapf
Created June 28, 2016 15:26
Show Gist options
  • Save karlnapf/1000e26a40fe74d08237841a2614080c to your computer and use it in GitHub Desktop.
Save karlnapf/1000e26a40fe74d08237841a2614080c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy as sp\n",
"from scipy.spatial.distance import pdist, squareform\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import modshogun as sg\n",
"from sklearn.cross_validation import KFold"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# sizes of the two samples and of the pooled sample\n",
"NX = 10\n",
"NY = 5\n",
"NXY = NX + NY\n",
"\n",
"# dimension of every sample point; deliberately not called D, because a later\n",
"# cell binds D to the pairwise distance matrix and re-running this cell\n",
"# afterwards would otherwise pass that matrix to randn as a shape\n",
"DIM = 2\n",
"X = np.random.randn(NX, DIM)\n",
"Y = np.random.randn(NY, DIM) + 3  # second sample: same noise, shifted by 3 per coordinate"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# pooled sample: the rows of X followed by the rows of Y\n",
"XY = np.concatenate((X, Y), axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD7CAYAAABKWyniAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEL9JREFUeJzt3WuwldV9x/Hvz3qJAbzGSyMjR2K8jEa8RW3Uuqt1QJOo\no3XG2CSijS8aL0x0rGheeOwri01TRuMLI+K9TgWnko4XZMhxqp3YAOIFwWC9Yga8kBQwGQP674u9\ngRM8+5zNWvt5Nmb9PjN7Zl+e//4vOOd3nn151rMUEZhZWbbr9QDMrH4OvlmBHHyzAjn4ZgVy8M0K\n5OCbFWj7qhtI8veFZj0SERrq/sqD33RDm/sHgMYwdbtk9Fw/zGNPAqdnPPdwjhjmsfuAb1fUd+dh\nHrsLmNz+4b5GWss37hvmwYeBc4d5/N20nsC+cU7bx9b2T2dM/5S2j6/UAcl9WTZkhppu6Ycr+ts+\n/O7BY5Lb7v3E2vYP3tsP32nTd1L78fqlvlmBsoIvaZKkZZJ+Jenabg3KzKqVHHxJ2wG3AhOBw4Bv\nSTpk656lL7V9pvE96jvc24AqHdmDnof2oCfs2Di+J305rtGbvkek9c3Z4x8HLI+INyNiPfAgcPbW\nPUVfRvscX+pRXwe/ajs1TuhJX45v9KbvhLS+OcHfD3h70O0VrfvMbBtX06f6A4Ou99G7Pb3Zn7Dn\nB+CFgY42zQn+O8D+g26Pbd03hEZGGzPryITGH7/0v//GtpvmvNT/JXCgpHGSdgQuAOZkPJ+Z1SR5\njx8RH0u6HJhL8w/IjIhY2rWRmVllst7jR8TjwMFdGouZ1cRH7pkVyME3K1BNX+elTrZZ09VRdC7n\n4JMVGbWXZtQuTC/9RmLdravTe2b8bFfOzjny8oP00gf3TC792Q3fTO/7UnppO97jmxXIwTcrkINv\nViAH36xADr5ZgRx8swI5+GYFcvDNCuTgmxXIwTcrkINvViAH36xADr5ZgRx8swLVNC13uHXstkW/\nz6jdIaM2Y8ooG9JL389o2wu7ffb6jmGY9e9GUkFKvcc3K5CDb1agnLXzxkqaL2mJpBclXdnNgZlZ\ndXLePWwAroqIxZJGAwslzY2IZV0am5lVJHmPHxErI2Jx6/o6YCleO8/sM6Er7/El9dFckvXZbjyf\nmVUrO/itl/mzgCmtPb+ZbeOyviGUtD3N0N8bEY+03/LJQdfH07v16c3+hC0fgFcHOto099CAO4GX\nI2L68JudntnGzEb05UbzstETFayWK+lE4G+BUyU9J2mRpEmpz2dm9clZLfcZ4M+6OBYzq4mP3DMr\nkINvViAH36xANU3LTZWzam3O1No3Mmqvyqi9P6P2ovTSB1NX+D0jvWfGsV6jT3gvuXYdGYeaNNJX\ny+3L+Z3KmHHdjvf4ZgVy8M0K5OCbFcjBNyuQg29WIAffrEAOvlmBHHyzAjn4ZgVy8M0K5OCbFcjB\nNyuQg29WIAffrECKiGobSAGPJlanTheFvFVrz82o/Zf00pP602uf/mly6V4xManuPd2Z3BMOTK48\nL+NXdvbF306uHTczfZGoN6anTzE/c8rspLrHdB4RoaEe8x7frEAOvlmBurGSznatU2vP6caAzKx6\n3djjTwFe7sLzmFlNsoIvaSxwJnBHd4ZjZnXI3eP/GLgGqParATPrqpwltL4OrIqIxYBaFzP7DMg5\nvfaJwFmSzgR2BsZIuicivvvpTe8bdP2I1sXMuumDgZdYPbCko21z1s67HrgeQNIpwNVDhx4g/aAJ\nM+vMno3D2bNx+Kbbr97472239ff4ZgXqyko6EfEU8FQ3nsvMquc9vlmBHHyzAjn4ZgXaxlfLvTSj\n9oOM2oxVa7Om1mbUfiG99j09mVY4Nr0nK9Ynlzb4++Ta2U+nf8N0IQ8k17I6vXQBx6YXt+E9vlmB\nHHyzAjn4ZgVy8M0K5OCbFcjBNyuQg29WIAff
rEAOvlmBHHyzAjn4ZgVy8M0K5OCbFcjBNytQTavl\n/jyxekxG5w0ZtV/JqM2Y0vuFjGnI7/cnl45ed1lS3brR/5HcExrJlb9fd1By7c4L0n/fB045Prn2\n5/qf5NqzIm2l3WO01KvlmtlmDr5ZgXLXzttV0kOSlkpaIin9tZCZ1Sb31FvTgUcj4nxJ2wOf78KY\nzKxiycGXtAtwckRMBoiIDcCaLo3LzCqU81L/AOB9STMlLZJ0u6SduzUwM6tOzkv97YGjgcsiYoGk\nfwWmAjd8etO7Bl0/snUxs25aMPAhCwd+19G2OcFfAbwdEQtat2cB1w696eSMNmbWiWMbozi2MWrT\n7dtvfL/ttskv9SNiFfC2pI1HU5wGvJz6fGZWn9xP9a8E7pe0A/AacHH+kMysalnBj4jnga92aSxm\nVhMfuWdWIAffrEAOvlmB6lktt6+RVveNjJ7tv8kY2YMrkkv3ionJtcmr1pI+tRZg3eifpBUu60/u\nSUbpJaNmpBdn/E7dtvb7ybXT4vXk2r4ZqV+WDTkjF/Ae36xIDr5ZgRx8swI5+GYFcvDNCuTgmxXI\nwTcrkINvViAH36xADr5ZgRx8swI5+GYFcvDNClTTopn3Jlav7upYOndGRm3Goplj+9NrV/w0vXZZ\n4mKdh/Sn92S/5MpT48+Ta+efkT49b4//fCe59oOHxybXHnP+fyXVLdLJXjTTzDZz8M0K5OCbFSh3\ntdzrWqvkviDpfkk7dmtgZlad5OBLGgdcChwVEUfQPI3XBd0amJlVJ+ece2uAPwCjJH1Cc4nsX3dl\nVGZWqZwltH4D/Ah4C3gH+G1EzOvWwMysOsl7fEnjgR8A44D/A2ZJujAiHvj01g8Pun5o62Jm3bR2\n4DnWDjzX0bY5L/WPBZ6JiNUAkh4GvgYMEfxzM9qYWSfGNI5iTOOoTbdX3jiz7bY5n+q/Apwg6XOS\nRHO13KUZz2dmNcl5j/88cA+wEHie5tn7b+/SuMysQrmr5d4M3NylsZhZTXzknlmBHHyzAtWzaCbv\nJtat6eooOvdsRu2B6aUr1mf0baSX9qcWpk+tbR76kWb+jMRpxACPpy9euXrqAcm1026+PLl20XUn\nJde24z2+WYEcfLMCOfhmBXLwzQrk4JsVyME3K5CDb1YgB9+sQA6+WYEcfLMCOfhmBXLwzQrk4JsV\nyME3K1Atq+XuG/+bVLty9vj0xrull44+4b3k2omjnkiubTCQXPu9D2ck114yKq12FXsn95w/I33V\nWr7Xn17LP2TUTsuo/UpG7drEuou9Wq6ZbebgmxVoxOBLmiFplaQXBt23u6S5kl6R9ISkXasdppl1\nUyd7/JnAxC3umwrMi4iDgfnAdd0emJlVZ8TgR8TTwG+2uPts4O7W9buBc7o8LjOrUOp7/L0jYhVA\nRKyEjI93zax23fpwr9rvBM2sq1JPr71K0j4RsUrSvoxw/uy1/dM3Xd+xcTw7NU5IbGtm7S1rXUbW\nafDVumw0B5gM/BNwEfDIcMVj+qd02MbM0h3SumzUPpadfJ33APDfwEGS3pJ0MXATcLqkV2iukntT\n1njNrFYj7vEj4sI2D/11l8diZjXxkXtmBXLwzQrk4JsVqJZpufBJYvXqro6lcxmr9E5OX1GVp9NL\nuSOjNnWGbM4irhmr1sI+GbU5U2uvSC+9fM/02lsfSyw809NyzWwzB9+sQA6+WYEcfLMCOfhmBXLw\nzQrk4JsVyME3K5CDb1YgB9+sQA6+WYEcfLMCOfhmBXLwzQpUz7TcZYk9HsxonLFaLo300nETOjvL\n6VAu5IHk2omkr9J7G99Pqpv3cfrZ11ZP3S+5ln/uT6/NmVrLLRm1Z2TUrkis+xtPyzWzzRx8swKl\nrpY7TdJSSYslzZa0S7XDNLNuSl0tdy5wWEQcCSzHq+WafaYkrZYbEfMiYuOJ9H4BjK1gbGZWkW68\nx78ESD0b
oJn1QFbwJf0QWB8R6d9DmVntUlfLRdJk4Ezg1BE3vqV/8/XjGnB8I7WtmbX1ErCkoy2T\nVsuVNAm4BvjLiPhoxOor+jtsY2bpDm9dNnqo7Zapq+XeAowGnpS0SNJteQM2szqlrpY7s4KxmFlN\nfOSeWYEcfLMCOfhmBaplWu67MTqp9md8M7nvGNYm1/bxRnLtV6e/lFybszhw/z+m114ceyXVjXvo\nveSe086/PLn2WjWSa7n8vPTaW59Nr806xu3cxLoJnpZrZps5+GYFcvDNCuTgmxXIwTcrkINvViAH\n36xADr5ZgRx8swI5+GYFcvDNCuTgmxXIwTcrkINvVqB6Vst9PLFHxgzX9PMHAxvSS8+4+uHk2gUc\nm1z7OJOSa4+Z8XJS3dF/93Ryz0XXnZRcy013pdeyT0bt7zJqv5xRm/o7daOn5ZrZZg6+WYGSVssd\n9NjVkj6RtEc1wzOzKqSuloukscDpwJvdHpSZVStptdyWH9NcTcfMPmOS3uNLOgt4OyJe7PJ4zKwG\nW/2ll6SdgetpvszfdPewRff2b75+RAMmNLa2rZmN6I3WZWQp33Z/CegDnpckYCywUNJxEfHukBXf\n6U9oY2Zbp6912eiptltu9Wq5EfESsO+mB6TXgaMjYqjPAcxsG5S6Wu5gwUgv9c1sm9LJp/oXRsQX\nI2KniNg/ImZu8fj4iEhbA+b5gaSybMt70/eDgZxjkNMtGPiw9p5rB56rvWfTsh717c3PttP39Fvq\n7ZF7Lwz0pu+rvem7emBJT/ouHMg5xjxNecHvzc/2sxl8M+sJB9+sQPVMyzWznmg3Lbfy4JvZtscv\n9c0K5OCbFahnwZc0SdIySb+SdG1NPcdKmi9piaQXJV1ZR99W7+0kLZI0p8aeu0p6SNLS1r/5+Jr6\nXtfq94Kk+yXtWFGfT50rQtLukuZKekXSE5J2ranvtNb/82JJsyXtUnXPQY9t9XkxehJ8SdsBt9Kc\n538Y8C1Jh9TQegNwVUQcBvwFcFlNfQGmAGknt0s3HXg0Ig4FJgBLq24oaRxwKXBURBxB87DwCypq\nN9S5IqYC8yLiYGA+cF1NfecCh0XEkcDyCvp29bwYvdrjHwcsj4g3I2I98CBwdtVNI2JlRCxuXV9H\nMwj7Vd239cM5E7ij6l6Deu4CnLzxSMuI2BARa2povQb4AzBK0vbA54FfV9Gozbkizgbubl2/Gzin\njr4RMS8iPmnd/AXNyWuV9mxJOi9Gr4K/H/D2oNsrqCGAg0nqA44Enq2h3cYfTp1foRwAvC9pZust\nxu2tKdWVak3W+hHwFvAO8NuImFd130H2johVrbGsBPausfdGlwCPVd0k57wYRX64J2k0MAuY0trz\nV9nr68Cq1iuNTbMca7A9cDTwk4g4mua5oadW3VTSeOAHwDjgi8BoSRdW3XcYtX5fLemHwPqIeKDi\nPhvPi3HD4Ls7re9V8N8B9h90e2zrvsq1Xn7OAu6NiEdqaHkicJak14B/A/5K0j019F1Bc2+woHV7\nFs0/BFU7FngmIlZHxMc0Twr/tRr6brRK0j4AkvYFhj5HRAUkTab5lq6OP3SDz4vxOpvPi9HRK5xe\nBf+XwIGSxrU+8b0AqOvT7juBlyNieh3NIuL61qzG8TT/nfMj4rs19F0FvC3poNZdp1HPh4uvACdI\n+lzrRC2nUe2Hilu+ipoDTG5dvwio6o/7H/WVNInm27mzIuKjqntGxEsRsW9rduwBNP/QH9X2ZDhb\nioieXIBJNH9JlgNTa+p5IvAxsBh4DlgETKrx33wKMKfGfhNo/pFdTHPPu2tNfa+hOV3tBZofsO1Q\nUZ8HaH5w+BHNzxQuBnYH5rV+t+YCu9XUdznNT9YXtS63Vd1zi8dfA/bo9Pl8yK5ZgYr8cM+sdA6+\nWYEcfLMCOfhmBXLwzQrk4JsVyME3K5CDb1ag/wesFHOm67vFaQAAAABJRU
5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f51d8c90e90>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# condensed vector of squared Euclidean distances between all pairs of pooled points\n",
"condensed_sq_dists = pdist(XY, metric='sqeuclidean')\n",
"# expand to the full symmetric square matrix and show it as an image\n",
"D = squareform(condensed_sq_dists)\n",
"plt.imshow(D, interpolation='nearest');"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20.1246117975\n"
]
}
],
"source": [
"# sizes of the nested loops of the experiment\n",
"num_repetitions = 3\n",
"num_folds = 3\n",
"num_null_samples = 5\n",
"\n",
"# candidate kernel widths on a log scale, one kernel per value\n",
"log_sigmas = np.linspace(-2,2,3)\n",
"num_kernels = len(log_sigmas)\n",
"\n",
"# NOTE(review): the trailing 3 presumably matches the last axis of mmd_all - confirm\n",
"params = (num_repetitions, num_folds, num_kernels, num_null_samples, 3)\n",
"total_size = np.prod(params)\n",
"print np.sqrt(total_size)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"p [ 9 4 5 6 7 3 1 2 8 0 11 14 12 13 10]\n",
"train_x [4 5 6 7 8 9]\n",
"train_y [12 13 14]\n"
]
},
{
"ename": "NameError",
"evalue": "name 'is_in_xvalidation_trainig' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-79-7d8a1e18547f>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mcol_idx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mNY\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 50\u001b[0m \u001b[1;31m# some offset stuff going on inside here\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 51\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mis_in_xvalidation_trainig\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_y\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrow_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcol_idx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 52\u001b[0m \u001b[0mwhich_mmd_term\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_mmd_term_to_add\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrow_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcol_idx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 53\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mkernel_idx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_kernels\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mNameError\u001b[0m: name 'is_in_xvalidation_trainig' is not defined"
]
}
],
"source": [
"# this is the case for translation invariant kernels\n",
"# the first step should be just precomputing the kernel matrix itself,\n",
"# i.e. the loop over kernels is outside\n",
"#\n",
"# NOTE(review): is_in_xvalidation_trainig, get_mmd_term_to_add and k are not\n",
"# defined in any cell of this notebook, so the innermost loop raises NameError\n",
"# until they are provided.\n",
"\n",
"# hack for sklearn: materialise the training indices of every fold.\n",
"# The folds are built without shuffling, so they do not change between\n",
"# repetitions and are computed once, outside the repetition loop.\n",
"kf_x = KFold(NX, n_folds=num_folds)\n",
"kf_y = KFold(NY, n_folds=num_folds)\n",
"trains_x = [train for train, _ in kf_x]\n",
"trains_y = [train for train, _ in kf_y]\n",
"\n",
"#s = sg.SubsetStack()\n",
"for repetition in range(num_repetitions):\n",
"    # this is permuting X,Y effectively, without precomputing D\n",
"    p1 = np.random.permutation(NX)\n",
"    p2 = np.random.permutation(NY)+NX\n",
"    p = np.hstack((p1,p2))\n",
"    print \"p\", p\n",
"    #s.add_subset(p)\n",
"\n",
"    for fold in range(num_folds):\n",
"        train_x = trains_x[fold]\n",
"        train_y = trains_y[fold]+NX # important: offset\n",
"\n",
"        print \"train_x\", train_x\n",
"        print \"train_y\", train_y\n",
"\n",
"        train = np.hstack((train_x, train_y))\n",
"        #s.add_subset(train) # D is masked to be smaller now\n",
"\n",
"        # the accumulators must start at zero: entries are only ever updated\n",
"        # with +=, and np.empty would leave arbitrary values in them\n",
"        mmd_all = np.zeros((num_kernels, num_null_samples+1, 3))\n",
"\n",
"        for null_sample in range(num_null_samples+1):\n",
"            # only add permutation for null samples, last iteration computes MMD itself\n",
"            if null_sample<num_null_samples:\n",
"                # make sure this is thread safe, i.e. clone the subset stack\n",
"                # and add different permutation to each\n",
"                # TODO: flattened is not used by the accumulation below yet\n",
"                flattened = np.random.permutation(len(train_x)+len(train_y))#s.get_active_subset()\n",
"\n",
"            # precompute index map\n",
"\n",
"            # this iterates over FULL D matrix\n",
"            # NOTE(review): the ranges below only cover rows 0..NX-1 and\n",
"            # columns 0..NY-1 of D - confirm that this is intended\n",
"            for row_idx in range(NX):\n",
"                for col_idx in range(NY):\n",
"                    # some offset stuff going on inside here\n",
"                    if is_in_xvalidation_trainig(train_x, train_y, row_idx, col_idx):\n",
"                        which_mmd_term = get_mmd_term_to_add(row_idx, col_idx)\n",
"                        for kernel_idx in range(num_kernels):\n",
"                            # one slot per null_sample iteration; the last slot\n",
"                            # (null_sample == num_null_samples) holds the MMD itself\n",
"                            mmd_all[kernel_idx, null_sample, which_mmd_term] += k(\n",
"                                D[row_idx, col_idx], log_sigmas[kernel_idx])\n",
"\n",
"        # for every kernel, perform test\n",
"\n",
"        # compute test results, num_kernels-dim vector\n",
"        # update type_2 error for each kernel"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([0, 1, 2, 3, 4]))"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# first (train, test) index pair of a plain k-fold split over all NXY pooled samples;\n",
"# kf is built here so that the cell does not depend on a since-deleted cell\n",
"kf = KFold(NXY, n_folds=num_folds)\n",
"next(iter(kf))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment