Created
June 28, 2016 15:26
-
-
Save karlnapf/1000e26a40fe74d08237841a2614080c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import scipy as sp\n", | |
"from scipy.spatial.distance import pdist, squareform\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"import modshogun as sg\n", | |
"from sklearn.cross_validation import KFold" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"NX=10\n", | |
"NY=5\n", | |
"NXY = NX+NY\n", | |
"\n", | |
"D=2\n", | |
"X = np.random.randn(NX,D)\n", | |
"Y = np.random.randn(NY,D)+3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"XY = np.vstack((X,Y))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD7CAYAAABKWyniAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEL9JREFUeJzt3WuwldV9x/Hvz3qJAbzGSyMjR2K8jEa8RW3Uuqt1QJOo\no3XG2CSijS8aL0x0rGheeOwri01TRuMLI+K9TgWnko4XZMhxqp3YAOIFwWC9Yga8kBQwGQP674u9\ngRM8+5zNWvt5Nmb9PjN7Zl+e//4vOOd3nn151rMUEZhZWbbr9QDMrH4OvlmBHHyzAjn4ZgVy8M0K\n5OCbFWj7qhtI8veFZj0SERrq/sqD33RDm/sHgMYwdbtk9Fw/zGNPAqdnPPdwjhjmsfuAb1fUd+dh\nHrsLmNz+4b5GWss37hvmwYeBc4d5/N20nsC+cU7bx9b2T2dM/5S2j6/UAcl9WTZkhppu6Ycr+ts+\n/O7BY5Lb7v3E2vYP3tsP32nTd1L78fqlvlmBsoIvaZKkZZJ+Jenabg3KzKqVHHxJ2wG3AhOBw4Bv\nSTpk656lL7V9pvE96jvc24AqHdmDnof2oCfs2Di+J305rtGbvkek9c3Z4x8HLI+INyNiPfAgcPbW\nPUVfRvscX+pRXwe/ajs1TuhJX45v9KbvhLS+OcHfD3h70O0VrfvMbBtX06f6A4Ou99G7Pb3Zn7Dn\nB+CFgY42zQn+O8D+g26Pbd03hEZGGzPryITGH7/0v//GtpvmvNT/JXCgpHGSdgQuAOZkPJ+Z1SR5\njx8RH0u6HJhL8w/IjIhY2rWRmVllst7jR8TjwMFdGouZ1cRH7pkVyME3K1BNX+elTrZZ09VRdC7n\n4JMVGbWXZtQuTC/9RmLdravTe2b8bFfOzjny8oP00gf3TC792Q3fTO/7UnppO97jmxXIwTcrkINv\nViAH36xADr5ZgRx8swI5+GYFcvDNCuTgmxXIwTcrkINvViAH36xADr5ZgRx8swLVNC13uHXstkW/\nz6jdIaM2Y8ooG9JL389o2wu7ffb6jmGY9e9GUkFKvcc3K5CDb1agnLXzxkqaL2mJpBclXdnNgZlZ\ndXLePWwAroqIxZJGAwslzY2IZV0am5lVJHmPHxErI2Jx6/o6YCleO8/sM6Er7/El9dFckvXZbjyf\nmVUrO/itl/mzgCmtPb+ZbeOyviGUtD3N0N8bEY+03/LJQdfH07v16c3+hC0fgFcHOto099CAO4GX\nI2L68JudntnGzEb05UbzstETFayWK+lE4G+BUyU9J2mRpEmpz2dm9clZLfcZ4M+6OBYzq4mP3DMr\nkINvViAH36xANU3LTZWzam3O1No3Mmqvyqi9P6P2ovTSB1NX+D0jvWfGsV6jT3gvuXYdGYeaNNJX\ny+3L+Z3KmHHdjvf4ZgVy8M0K5OCbFcjBNyuQg29WIAffrEAOvlmBHHyzAjn4ZgVy8M0K5OCbFcjB\nNyuQg29WIAffrECKiGobSAGPJlanTheFvFVrz82o/Zf00pP602uf/mly6V4xManuPd2Z3BMOTK48\nL+NXdvbF306uHTczfZGoN6anTzE/c8rspLrHdB4RoaEe8x7frEAOvlmBurGSznatU2vP6caAzKx6\n3djjTwFe7sLzmFlNsoIvaSxwJnBHd4ZjZnXI3eP/GLgGqParATPrqpwltL4OrIqIxYBaFzP7DMg5\nvfaJwFmSzgR2BsZIuicivvvpTe8bdP2I1sXMuumDgZdYPbCko21z1s67HrgeQNIpwNVDhx4g/aAJ\nM+vMno3D2bNx+Kbbr97472239ff4ZgXqyko6EfEU8FQ3nsvMquc9vlmBHHyzAjn4ZgXaxlfLvTSj\n9oOM2oxVa7Om1mbUfiG99j09mVY4Nr0nK9Ynlzb4++Ta2U+nf8N0IQ8k17I6vXQBx6YXt+E9vlmB\nHHyzAjn4ZgVy8M0K5OCbFcjBNyuQg29WIAffrEAOvlmBHHyzAjn4ZgVy8M0K5OCbFcjBNytQTavl\n/jyxekxG5w0ZtV/JqM2Y0vuFjGnI7/cnl45ed1lS3brR/5HcExrJlb9fd1By7c4L0n/fB045Prn2\n5/qf5NqzIm2l3WO01KvlmtlmDr5ZgXLXzttV0kOSlkpaIin9tZCZ1Sb31FvTgUcj4nxJ2wOf78KY\nzKxiycGXtAtwckRMBoiIDcCaLo3LzCqU81L/AOB9STMlLZJ0u6SduzUwM6tOzkv97YGjgcsiYoGk\nfwWmAjd8etO7Bl0/snUxs25aMPAhCwd+19G2OcFfAbwdEQtat2cB1w696eSMNmbWiWMbozi2MWrT\n7dtvfL/ttskv9SNiFfC2pI1HU5wGvJz6fGZWn9xP9a8E7pe0A/AacHH+kMysalnBj4jnga92aSxm\nVhMfuWdWIAffrEAOvlmB6lktt6+RVveNjJ7tv8kY2YMrkkv3ionJtcmr1pI+tRZg3eifpBUu60/u\nSUbpJaNmpBdn/E7dtvb7ybXT4vXk2r4ZqV+WDTkjF/Ae36xIDr5ZgRx8swI5+GYFcvDNCuTgmxXI\nwTcrkINvViAH36xADr5ZgRx8swI5+GYFcvDNClTTopn3Jlav7upYOndGRm3Goplj+9NrV/w0vXZZ\n4mKdh/Sn92S/5MpT48+Ta+efkT49b4//fCe59oOHxybXHnP+fyXVLdLJXjTTzDZz8M0K5OCbFSh3\ntdzrWqvkviDpfkk7dmtgZlad5OBLGgdcChwVEUfQPI3XBd0amJlVJ+ece2uAPwCjJH1Cc4nsX3dl\nVGZWqZwltH4D/Ah4C3gH+G1EzOvWwMysOsl7fEnjgR8A44D/A2ZJujAiHvj01g8Pun5o62Jm3bR2\n4DnWDjzX0bY5L/WPBZ6JiNUAkh4GvgYMEfxzM9qYWSfGNI5iTOOoTbdX3jiz7bY5n+q/Apwg6XOS\nRHO13KUZz2dmNcl5j/88cA+wEHie5tn7b+/SuMysQrmr5d4M3NylsZhZTXzknlmBHHyzAtWzaCbv\nJtat6eooOvdsRu2B6aUr1mf0baSX9qcWpk+tbR76kWb+jMRpxACPpy9euXrqAcm1026+PLl20XUn\nJde24z2+WYEcfLMCOfhmBXLwzQrk4JsVyME3K5CDb1YgB9+sQA6+WYEcfLMCOfhmBXLwzQrk4JsV\nyME3K1Atq+XuG/+bVLty9vj0xrull44+4b3k2omjnkiubTCQXPu9D2ck114yKq12FXsn95w/I33V\nWr7Xn17LP2TUTsuo/UpG7drEuou9Wq6ZbebgmxVoxOBLmiFplaQXBt23u6S5kl6R9ISkXasdppl1\nUyd7/JnAxC3umwrMi4iDgfnAdd0emJlVZ8TgR8TTwG+2uPts4O7W9buBc7o8LjOrUOp7/L0jYhVA\nRKyEjI93zax23fpwr9rvBM2sq1JPr71K0j4RsUrSvoxw/uy1/dM3Xd+xcTw7NU5IbGtm7S1rXUbW\nafDVumw0B5gM/BNwEfDIcMVj+qd02MbM0h3SumzUPpadfJ33APDfwEGS3pJ0MXATcLqkV2iukntT\n1njNrFYj7vEj4sI2D/11l8diZjXxkXtmBXLwzQrk4JsVqJZpufBJYvXqro6lcxmr9E5OX1GVp9NL\nuSOjNnWGbM4irhmr1sI+GbU5U2uvSC+9fM/02lsfSyw809NyzWwzB9+sQA6+WYEcfLMCOfhmBXLw\nzQrk4JsVyME3K5CDb1YgB9+sQA6+WYEcfLMCOfhmBXLwzQpUz7TcZYk9HsxonLFaLo300nETOjvL\n6VAu5IHk2omkr9J7G99Pqpv3cfrZ11ZP3S+5ln/uT6/NmVrLLRm1Z2TUrkis+xtPyzWzzRx8swKl\nrpY7TdJSSYslzZa0S7XDNLNuSl0tdy5wWEQcCSzHq+WafaYkrZYbEfMiYuOJ9H4BjK1gbGZWkW68\nx78ESD0boJn1QFbwJf0QWB8R6d9DmVntUlfLRdJk4Ezg1BE3vqV/8/XjGnB8I7WtmbX1ErCkoy2T\nVsuVNAm4BvjLiPhoxOor+jtsY2bpDm9dNnqo7Zapq+XeAowGnpS0SNJteQM2szqlrpY7s4KxmFlN\nfOSeWYEcfLMCOfhmBaplWu67MTqp9md8M7nvGNYm1/bxRnLtV6e/lFybszhw/z+m114ceyXVjXvo\nveSe086/PLn2WjWSa7n8vPTaW59Nr806xu3cxLoJnpZrZps5+GYFcvDNCuTgmxXIwTcrkINvViAH\n36xADr5ZgRx8swI5+GYFcvDNCuTgmxXIwTcrkINvVqB6Vst9PLFHxgzX9PMHAxvSS8+4+uHk2gUc\nm1z7OJOSa4+Z8XJS3dF/93Ryz0XXnZRcy013pdeyT0bt7zJqv5xRm/o7daOn5ZrZZg6+WYGSVssd\n9NjVkj6RtEc1wzOzKqSuloukscDpwJvdHpSZVStptdyWH9NcTcfMPmOS3uNLOgt4OyJe7PJ4zKwG\nW/2ll6SdgetpvszfdPewRff2b75+RAMmNLa2rZmN6I3WZWQp33Z/CegDnpckYCywUNJxEfHukBXf\n6U9oY2Zbp6912eiptltu9Wq5EfESsO+mB6TXgaMjYqjPAcxsG5S6Wu5gwUgv9c1sm9LJp/oXRsQX\nI2KniNg/ImZu8fj4iEhbA+b5gaSybMt70/eDgZxjkNMtGPiw9p5rB56rvWfTsh717c3PttP39Fvq\n7ZF7Lwz0pu+rvem7emBJT/ouHMg5xjxNecHvzc/2sxl8M+sJB9+sQPVMyzWznmg3Lbfy4JvZtscv\n9c0K5OCbFahnwZc0SdIySb+SdG1NPcdKmi9piaQXJV1ZR99W7+0kLZI0p8aeu0p6SNLS1r/5+Jr6\nXtfq94Kk+yXtWFGfT50rQtLukuZKekXSE5J2ranvtNb/82JJsyXtUnXPQY9t9XkxehJ8SdsBt9Kc\n538Y8C1Jh9TQegNwVUQcBvwFcFlNfQGmAGknt0s3HXg0Ig4FJgBLq24oaRxwKXBURBxB87DwCypq\nN9S5IqYC8yLiYGA+cF1NfecCh0XEkcDyCvp29bwYvdrjHwcsj4g3I2I98CBwdtVNI2JlRCxuXV9H\nMwj7Vd239cM5E7ij6l6Deu4CnLzxSMuI2BARa2povQb4AzBK0vbA54FfV9Gozbkizgbubl2/Gzin\njr4RMS8iPmnd/AXNyWuV9mxJOi9Gr4K/H/D2oNsrqCGAg0nqA44Enq2h3cYfTp1foRwAvC9pZust\nxu2tKdWVak3W+hHwFvAO8NuImFd130H2johVrbGsBPausfdGlwCPVd0k57wYRX64J2k0MAuY0trz\nV9nr68Cq1iuNTbMca7A9cDTwk4g4mua5oadW3VTSeOAHwDjgi8BoSRdW3XcYtX5fLemHwPqIeKDi\nPhvPi3HD4Ls7re9V8N8B9h90e2zrvsq1Xn7OAu6NiEdqaHkicJak14B/A/5K0j019F1Bc2+woHV7\nFs0/BFU7FngmIlZHxMc0Twr/tRr6brRK0j4AkvYFhj5HRAUkTab5lq6OP3SDz4vxOpvPi9HRK5xe\nBf+XwIGSxrU+8b0AqOvT7juBlyNieh3NIuL61qzG8TT/nfMj4rs19F0FvC3poNZdp1HPh4uvACdI\n+lzrRC2nUe2Hilu+ipoDTG5dvwio6o/7H/WVNInm27mzIuKjqntGxEsRsW9rduwBNP/QH9X2ZDhb\nioieXIBJNH9JlgNTa+p5IvAxsBh4DlgETKrx33wKMKfGfhNo/pFdTHPPu2tNfa+hOV3tBZofsO1Q\nUZ8HaH5w+BHNzxQuBnYH5rV+t+YCu9XUdznNT9YXtS63Vd1zi8dfA/bo9Pl8yK5ZgYr8cM+sdA6+\nWYEcfLMCOfhmBXLwzQrk4JsVyME3K5CDb1ag/wesFHOm67vFaQAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x7f51d8c90e90>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"D = squareform(pdist(XY, 'sqeuclidean'))\n", | |
"plt.imshow(D, interpolation='nearest');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"20.1246117975\n" | |
] | |
} | |
], | |
"source": [ | |
"num_repetitions = 3\n", | |
"num_folds = 3\n", | |
"log_sigmas = np.linspace(-2,2,3)\n", | |
"num_kernels = len(log_sigmas)\n", | |
"num_null_samples = 5\n", | |
"\n", | |
"params = (num_repetitions, num_folds, num_kernels, num_null_samples, 3)\n", | |
"print np.sqrt(np.prod(params))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"p [ 9 4 5 6 7 3 1 2 8 0 11 14 12 13 10]\n", | |
"train_x [4 5 6 7 8 9]\n", | |
"train_y [12 13 14]\n" | |
] | |
}, | |
{ | |
"ename": "NameError", | |
"evalue": "name 'is_in_xvalidation_trainig' is not defined", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-79-7d8a1e18547f>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mcol_idx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mNY\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 50\u001b[0m \u001b[1;31m# some offset stuff going on inside here\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 51\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mis_in_xvalidation_trainig\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_y\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrow_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcol_idx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 52\u001b[0m \u001b[0mwhich_mmd_term\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_mmd_term_to_add\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrow_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcol_idx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 53\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mkernel_idx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_kernels\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mNameError\u001b[0m: name 'is_in_xvalidation_trainig' is not defined" | |
] | |
} | |
], | |
"source": [ | |
"# this is the case for translation invariant kernels\n", | |
"# the first step should be just precomputing the kernel matrix itself, \\\n", | |
"# i.e. the loop over kernels is outside\n", | |
"\n", | |
"\n", | |
"\n", | |
"#s = sg.SubsetStack()\n", | |
"for repetition in range(num_repetitions):\n", | |
" # this is permuting X,Y effectively, without precomputing D\n", | |
" p1 = np.random.permutation(NX)\n", | |
" p2 = np.random.permutation(NY)+NX\n", | |
" p = np.hstack((p1,p2))\n", | |
" print \"p\", p\n", | |
" #s.add_subset(p)\n", | |
" \n", | |
" # hack for sklearn\n", | |
" kf_x = KFold(NX, n_folds=num_folds)\n", | |
" kf_y = KFold(NY, n_folds=num_folds)\n", | |
" trains_x = []\n", | |
" for train, _ in kf_x:\n", | |
" trains_x += [train]\n", | |
" trains_y = []\n", | |
" for train, _ in kf_y:\n", | |
" trains_y += [train]\n", | |
" \n", | |
" for fold in range(num_folds):\n", | |
" train_x = trains_x[fold]\n", | |
" train_y = trains_y[fold]+NX # important: offset\n", | |
" \n", | |
" print \"train_x\", train_x\n", | |
" print \"train_y\", train_y\n", | |
" \n", | |
" train = np.hstack((train_x, train_y))\n", | |
" #s.add_subset(train) # D is masked to be smaller now\n", | |
" \n", | |
" mmd_all = np.empty((num_kernels, num_null_samples+1, 3))\n", | |
"\n", | |
" \n", | |
" for null_sample in range(num_null_samples+1):\n", | |
" # only add permutation for null samples, last iteration computes MMD itself\n", | |
" if null_sample<num_null_samples:\n", | |
" # make sure this is thread safe, i.e. clone the subset stack\n", | |
" # and add different permutation to each\n", | |
" flattened = np.random.permutation(len(train_x)+len(train_y))#s.get_active_subset()\n", | |
" \n", | |
" # precompute index map\n", | |
" \n", | |
" # this iterates over FULL D matrix\n", | |
" for row_idx in range(NX):\n", | |
" for col_idx in range(NY):\n", | |
" # some offset stuff going on inside here\n", | |
" if is_in_xvalidation_trainig(train_x, train_y, row_idx, col_idx):\n", | |
" which_mmd_term = get_mmd_term_to_add(row_idx, col_idx)\n", | |
" for kernel_idx in range(num_kernels):\n", | |
" mmd_all[kernel_idx, num_null_samples, which_mmd_term] += \\\n", | |
" k(D[row_idx, col_idx], log_sigmas[kernel_idx])\n", | |
" \n", | |
" # for every kernel, perform test\n", | |
" \n", | |
" # compute test results, num_kernels-dim vector\n", | |
" # update type_2 error for each kernel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), array([0, 1, 2, 3, 4]))" | |
] | |
}, | |
"execution_count": 73, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"kf.__iter__().__iter__().next()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment