Last active
December 28, 2015 16:12
-
-
Save raghavrv/bac002a5668f8fe88db2 to your computer and use it in GitHub Desktop.
Test for non-reset upon partial fit/reset upon fit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:2190a1d956e09864439d4398f0e615fa400757b27467843876fd36eac70f0bc7" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin\n", | |
"from sklearn.svm import LinearSVC\n", | |
"from sklearn.utils import check_consistent_length\n", | |
"from sklearn.externals.joblib import Parallel, delayed\n", | |
"from sklearn.multiclass import _fit_binary, _fit_ovo_binary, OneVsOneClassifier\n", | |
"from sklearn.linear_model import SGDClassifier\n", | |
"from sklearn import datasets\n", | |
"from sklearn.preprocessing import scale\n", | |
"import numpy as np\n", | |
"import functools\n", | |
"\n", | |
"%matplotlib inline\n", | |
"import matplotlib.pyplot as plt" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from sklearn.utils.testing import assert_raises, assert_equal, assert_array_almost_equal, assert_almost_equal" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"iris = datasets.load_iris()\n", | |
"X, y, bX, by, est_pfit = (None,)*5\n", | |
"\n", | |
"def init_data(init, shuffle):\n", | |
" global X, y, bX, by, est_pfit\n", | |
"\n", | |
" \n", | |
" if init:\n", | |
" X, y = iris.data, iris.target\n", | |
" X = scale(X)\n", | |
"\n", | |
" idx = np.arange(X.shape[0])\n", | |
" if shuffle:\n", | |
" rng = np.random.RandomState(0)\n", | |
" rng.shuffle(idx)\n", | |
"\n", | |
" # Shuffle X and y and split them into 10 smaller batches\n", | |
" bX = [ X[idx[i:i+10]] for i in range(0, 100, 10) ]\n", | |
" by = [ y[idx[i:i+10]] for i in range(0, 100, 10) ]\n", | |
"\n", | |
" state = (bX, by)\n", | |
" return state\n", | |
"\n", | |
"def test_error_upon_feature_size_change():\n", | |
" global X, y, bX, by, est_pfit\n", | |
" \n", | |
" # 1st batch of 10 with 4 features\n", | |
" est_pfit.partial_fit(bX[0], by[0], classes=np.unique(y))\n", | |
"\n", | |
" # Attempting batch of 2 with 4 features\n", | |
" assert_raises(ValueError, est_pfit.partial_fit, [[1, 0, 1.5], [1.5, 2, 1]], [1,2])\n", | |
"\n", | |
"def initialize():\n", | |
" global X, y, bX, by, est_pfit\n", | |
" \n", | |
" est_pfit = SGDClassifier(n_iter = 1, shuffle=False, loss=\"log\")\n", | |
" \n", | |
"def part_fit():\n", | |
" global X, y, bX, by, est_pfit\n", | |
"\n", | |
" # 1st batch of 10 with 4 features\n", | |
" est_pfit.partial_fit(bX[0], by[0], classes=np.unique(y))\n", | |
"\n", | |
" # Partially fit all other batches\n", | |
" for xi, yi in zip(bX[1:], by[1:]):\n", | |
" est_pfit.partial_fit(xi, yi)\n", | |
"\n", | |
" return round((np.where(y == est_pfit.predict(X))[0].shape[0])*100.0/150.0)\n", | |
"\n", | |
" \n", | |
"def full_fit():\n", | |
" global X, y, bX, by, est_pfit\n", | |
" \n", | |
" est_pfit.fit(X, y)\n", | |
" return round((np.where(y == est_pfit.predict(X))[0].shape[0])*100.0/150.0)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# for i in range(10):\n", | |
"# initialize()\n", | |
"# init_data(init=True, shuffle=False)\n", | |
"# print \"F: \", full_fit(), \"%\"\n", | |
"# print \"F: \", full_fit(), \"%\"\n", | |
"\n", | |
"initialize()\n", | |
"init_data(init=True, shuffle=False)\n", | |
"points = []\n", | |
"\n", | |
"for i in range(1000):\n", | |
" init_data(init=False, shuffle=True)\n", | |
" #print \"P: \", part_fit(), \"%\"\n", | |
" \n", | |
" points.append(part_fit())\n", | |
"\n", | |
"plt.plot(points)\n", | |
"plt.show()" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "display_data", | |
"png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEACAYAAAC08h1NAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFfBJREFUeJzt3WuwZWV95/HvDxrkonQjRXfXiHTsGJWkoogXyCRkdkmr\nPc4ECVUSNKUg2mOVRWn5AqFNTfWpjKnopCTjZHQmBHUw4w1vBdQwQgjusZxScwFEmqsSsGHSzSRy\nDiJ008B/XuzV9OHQDaf3Xmdfzv5+qnadtdZel2c/3fU7z/mvy05VIUmaHgeNugGSpOEy+CVpyhj8\nkjRlDH5JmjIGvyRNGYNfkqbMswZ/ks8k2ZHk5nnLjk5ybZI7klyTZOW89zYnuSvJbUnetJQNlyT1\n57lG/J8D3rxg2UXAdVX1cuB6YDNAkl8FzgJOAP418Okkabe5kqRBPWvwV9V3gQcXLH4rcFkzfRlw\nRjN9OvDlqnq8qu4B7gJe315TJUlt6KfGv7qqdgBU1XZgdbP8RcC2eevd3yyTJI2RNk7u+swHSZog\nK/rYZkeSNVW1I8la4IFm+f3Ai+etd1yz7GmS+ItCkvpQVa2cN13MiD/Na48rgXOb6XOAK+YtPzvJ\noUleArwU+Jt97bCqfFWxZcuWkbdhXF72hX1hXzz7q03POuJP8kWgAxyT5KfAFuBjwFeTnAfcS+9K\nHqrq1iSXA7cCu4H3V9utlSQN7FmDv6resZ+3Nuxn/T8G/njQRkmSlo537o5Qp9MZdRPGhn2xl32x\nl32xNDLsakwSK0CSdICSUEM8uStJWkYMfkmaMga/JE0Zg1+Spkw/d+5qmbv7bjjvPDj44FG3ZPl6\n7DH4gz+AjRtH3RJNI4Nfz3DLLbBtG1xyyahbsnx94hPw3e8a/BoNg1/PMDcHv/EbcNppo27J8rV1\nK9x116hboWlljV/PMDcHK1c+93rq38qVvX6WRsHg1zMY/EvP4NcoLetSz89/DvfdN+pWTJ4bboCT\nTx51K5a3Vavgttt6L2nYlnXwX3AB/PmfwyteMeqWTJaHH4ZNm0bdiuVt/XrYvRvOPHPULdE0WtbB\n/8//DJdfDm9726hbIj3d8cfDP/zDqFuhSZJWntLTs6xr/LOz1qolaaFlHfyepJSkZ1r2wb9q1ahb\nIUnjZVnU+O+/Hz72MVj4mP8773TEL0kLLYvg//734ctfhi1bnr78s5+FNWtG0yZJGld9B3+SDwLv\nbWb/oqr+c5ItwCbggWb5R6rqWwO28TnNzsLv/A6cf/5SH0mSJl9fwZ/k14D3AK8FHgf+V5L/2bx9\ncVVd3FL7FsVaviQtXr8j/hOAH1TVLoAk3wH23IrS4tWmi+PVO5K0eP1e1XMLcGqSo5McAbwFOA4o\n4PwkNyW5NMlQ4tjr9SVp8foa8VfV7Uk+DvwV8DBwI/AE8F+B/1BVleSjwMX0SkJPMzMz89R0p9Oh\n0+kcwLHhoH38uvrGNw7kE0jSeOt2u3S73SXZd2rhNZD97CT5I2BbVf23ecvWAVdV1SsXrFuDHPOR\nR+CYY+DRR/vehSRNnCRUVSul9L5v4EpybPPzeOB3gS8mWTtvlTPplYRa5YlcSRrMINfxfz3JC4Hd\nwPur6qEk/yXJicCTwD3A+1po49N4IleSBtN38FfVb+9j2bsGa85z80SuJA1m4p7VY6lHkgYzccHv\niF+SBjNxwW+NX5IGM3HBPztrqUeSBjFxwe+IX5IGY/BL0pSZmOCvgl/5FfjUpwx+SRrExAT/o4/C\nj3/cm7bGL0n9m5jgn5vbO+2IX5L6NzHBPzsLaR5P5Ihfkvo3McE/Nwfr1vWmHfFLUv8mKviPP743\nbfBLUv8mKvhXr4ZTToGjjhp1ayRpck1M8O+5Y/d734ODDx51ayRpck1M8PtUTklqx0QFv7V9SRrc\nxAS/j2OWpHZMRPDffnvvUQ2WeiRpcBMR/DfeCOvXwxlnjLolkjT5+g7+JB9M8qPm9YFm2dFJrk1y\nR5JrkrRSnNm5E049FV7wgjb2JknTra/gT/JrwHuA1wInAv82yS8DFwHXVdXLgeuBzW00ctcuOOyw\nNvYkSep3xH8C8IOq2lVVTwDfAc4ETgcua9a5DGilOLNzp8EvSW3pN/hvAU5tSjtHAG8BXgysqaod\nAFW1HVjdRiN37YLnPa+NPUmSVvSzUVXdnuTjwF8BDwM3Ak/sa9V9bT8zM/PUdKfTodPpPOvxHPFL\nmjbdbpdut7sk+07VPrP5wHaS/BGwDfgg0KmqHUnWAt+uqhMWrFsHeszNm3vP59ncyhkDSZo8Saiq\ntLGvQa7qObb5eTzwu8AXgSuBc5tVzgGuGLB9gKUeSWpTX6WexteTvBDYDby/qh5qyj+XJzkPuBc4\nq41GWuqRpPb0HfxV9dv7WPYzYMNALdoHg1+S2jMRd+5a6pGk9kxE8Dvil6T2TETwe+euJLVnIoJ/\n505LPZLUlokJfkf8ktQOg1+SpsxEBL9X9UhSeyYi+B3xS1J7Jib4HfFLUjsmIvi9nFOS2jMRwW+p\nR5LaMzHBb6lHktox9sG/ezc8+SQceuioWyJJy8PYB//cHKxcCWnl6wckSWMf/LOzveCXJLVj7IN/\nbg5WrRp1KyRp+ZiI4HfEL0ntGfvgt9QjSe0a++C31CNJ7eo7+JNsTrI1yc1JvpDkeUm2JLkvyQ3N\na+OgDbTUI0nt6uvL1pOsAzYBr6iqx5J8BTi7efviqrq4rQbOzjril6Q29Tvifwh4DDgyyQrgCOD+\n5r1Wr7h3xC9J7eor+KvqQeATwE/pBf5sVV3XvH1+kpuSXJpk4Mg2+CWpXf2WetYDHwLWAXPA15K8\nA/g08IdVVUk+ClwMvGfh9jMzM09NdzodOp3Ofo9l8EuaRt1ul263uyT7TlUd+EbJWcAbq2pTM/9O\n4OSqOn/eOuuAq6rqlQu2rQM55mmnwebNsGHDATdTkpaNJFRVK6X0fmv8dwCnJDksSYDTgNuSrJ23\nzpnALYM20BG/JLWrr1JPVf0wyeeBvweeAG4ALgE+k+RE4EngHuB9gzbQq3okqV19lXoGOuABlnqO\nPRa2boXVq5ewUZI05sah1DM0lnokqV1jHfxPPAGPP+63b0lSm8Y6+HfvhkMOGXUrJGl5MfglacqM\ndfA/9pjftStJbRvr4HfEL0ntM/glacoY/JI0ZQx+SZoyBr8kTRmDX5KmzFgHv5dzSlL7xjr4HfFL\nUvsMfkmaMga/JE0Zg1+SpozBL0lTxuCXpCkz1sG/cyccdtioWyFJy0vfwZ9kc5KtSW5O8oUkhyY5\nOsm1Se5Ick2Sgb40cdcug1+S2tZX8CdZB2wCXl1VrwRWAG8HLgKuq6qXA9cDmwdp3M6dfu2iJLWt\n3xH/Q8BjwJFJVgCHA/cDbwUua9a5DDhjkMZZ6pGk9vUV/FX1IPAJ4Kf0An+uqq4D1lTVjmad7cDq\nQRpnqUeS2rein42SrAc+BKwD5oCvJvl9oBasunAegJmZmaemO50OnU5nn8dxxC9pWnW7Xbrd7pLs\nO1X7zOZn3yg5C3hjVW1q5t8JnAK8AehU1Y4ka4FvV9UJC7atxR7zggvg2GPhwx8+4CZK0rKShKpK\nG/vqt8Z/B3BKksOSBDgNuBW4Eji3Wecc4IpBGmepR5La11epp6p+mOTzwN8DTwA3ApcALwAuT3Ie\ncC9w1iCNs9QjSe3rK/gBqupPgD9ZsPhnwIaBWjSPl3NKUvvG6s7dJ5+EM86AW27pzRv8ktS+sQr+\nnTvhiivgxht787t2weGHj7ZNkrTcjF3wA8zO7p13xC9J7RrL4J+b2zvvyV1JatdYBf+uXb2fe4Lf\nyzklqX1jFfz7GvFb6pGkdo1V8O8Z8e+p8Tvil6T2jVXwO+KXpKU3dsH/vOd5cleSltJYBf+uXbBm\njaUeSVpKYxX8O3f2gt9SjyQtnbEM/tlZqOqN+A1+SWrXWAX/ww/3nr//+OPwyCNw0EGwou/HyEmS\n9mWsgn9uDlatgpUrYccOR/uStBTGKvhnZ3uhv2pVL/g9sStJ7Rur4J+b6wX/nhG/wS9J7Rur4N8z\n4l+5ErZvt9QjSUthbIL/6qvh85+H445zxC9JS2lsgv+ee2DTJnjjG3s1/gceMPglaSn0dbFkkpcB\nXwEKCLAe+PfA0cAm4IFm1Y9U1bcWs8/ZWTjmmN70ypWwbZvBL0lLoa8Rf1XdWVWvrqqTgNcAvwC+\n2bx9cVWd1LwWFfqw98Qu7L2qxxq/JLWvjVLPBuAnVbWtmU8/O9lzYhes8UvSUmoj+H8P+NK8+fOT\n3JTk0iQrF7uTPTdvQe/nj3/sXbuStBQGitYkhwCnAxc1iz4N/GFVVZKPAhcD71m43czMzFPTnU6H\nTqfDww/D85/fW3bqqfDe98Kb3zxI6yRpcnW7Xbrd7pLsO1XV/8bJ6cD7q2rjPt5bB1xVVa9csLz2\ndcwNG+DCC3tX9UiSni4JVdVXKX2hQUs9b2demSfJ2nnvnQncstgd+ex9SRqOvks9SY6gd2L3381b\n/B+TnAg8CdwDvG+x+/PZ+5I0HH0Hf1U9Ahy7YNm7+t2fX7MoScMxNnfuWuqRpOEYm+B3xC9JwzFW\nwW+NX5KW3lgFvyN+SVp6YxP81vglaTjGIviffBIeewwOPXTULZGk5W8sgv/RR+HwwyGt3JMmSXo2\nYxH88x/JLElaWmMR/PMfySxJWlpjEfzzH8ksSVpaYxH8jvglaXjGIvit8UvS8IxN8FvqkaThGIvg\nt9QjScMzFsHviF+Shmdsgt8RvyQNx9gEvyN+SRqOsQj+hx6Co44adSskaTqMRfD7SGZJGp6+gj/J\ny5LcmOSG5udckg8kOTrJtUnuSHJNkkVV7g1+SRqevoK/qu6sqldX1UnAa4BfAN8ELgKuq6qXA9cD\nmxezP4NfkoanjVLPBuAnVbUNeCtwWbP8MuCMxexg1y6/dlGShqWN4P894IvN9Jqq2gFQVduB1YvZ\ngSN+SRqeFYNsnOQQ4HTgwmZRLVhl4TwAMzMzT013Oh127uw44pekebrdLt1ud0n2nap9ZvPiNk5O\nB95fVRub+duATlXtSLIW+HZVnbBgm1p4zLVr4aabej8lSc+UhKpq5XsKBy31vB340rz5K4Fzm+lz\ngCsWsxNLPZI0PH2P+JMcAdwLrK+qnzfLXghcDry4ee+sqppdsN0zRvyHHQYPPtj73l1J0jO1OeIf\nqNTT1wEXBH8VHHwwPP44HDQWt5NJ0vgZp1LPwL71rV74G/qSNBwjj9uf/ATOO2/UrZCk6THy4J+d\nhTVrRt0KSZoeIw9+H8ksScM1FsHvl7BI0vCMRfA74pek4Rl58PtF65I0XCMPfks9kjRcYxH8lnok\naXhGHvyWeiRpuEYe/I74JWm4Bnoef782boTdu+HRR3svH84mScMzkoe07fl+lu98pzfa//VfH2oT\nJGniTPzTOaFYuxb+8R+HemhJmljL4umchxwyqiNL0nQbWfAffPCojixJ021kwe/z9yVpNAx+SZoy\nfcdvkpVJvprktiRbk5ycZEuS+5Lc0Lw27m97Sz2SNBqDXMf/SeDqqnpbkhXAkcBG4OKquvi5Njb4\nJWk0+gr+JEcBp1bVuQBV9TgwlwRgUZcbWeqRpNHoN35fAvxTks81JZ1LkhzRvHd+kpuSXJpkv0/h\nccQvSaPRb/CvAE4CPlVVJwGPABcBnwbWV9WJwHZgvyUfR/ySNBr91vjvA7ZV1d81818DLqyq/zdv\nnb8Artr35jM88ADMzECn06HT6fTZDElanrrdLt1ud0n23fcjG5L8b2BTVd2ZZAtwBPCnVbW9ef9D\nwOuq6h0Ltiso3vAG+Ou/HrD1kjQl2nxkwyBX9XwA+EKSQ4C7gXcDf5bkROBJ4B7gffvb+Igj9veO\nJGkp9R38VfVD4HULFr9rsdsfeWS/R5YkDWJkp1gd8UvSaIwk+I85BjZsGMWRJUkjeR7/sI8pSZNu\nWTyPX5I0Gga/JE0Zg1+SpozBL0lTxuCXpClj8EvSlDH4JWnKGPySNGUMfkmaMga/JE0Zg1+SpozB\nL0lTxuCXpClj8EvSlDH4JWnKGPySNGX6Dv4kK5N8NcltSbYmOTnJ0UmuTXJHkmuSrGyzsZKkwQ0y\n4v8kcHVVnQC8CrgduAi4rqpeDlwPbB68ictXt9sddRPGhn2xl32xl32xNPoK/iRHAadW1ecAqurx\nqpoD3gpc1qx2GXBGK61cpvxPvZd9sZd9sZd9sTT6HfG/BPinJJ9LckOSS5IcAaypqh0AVbUdWN1W\nQyVJ7eg3+FcAJwGfqqqTgF/QK/Ms/BZ1v1VdksZMqg48m5OsAb5XVeub+d+iF/y/DHSqakeStcC3\nm3MA87f1l4Ek9aGq0sZ+VvR58B1JtiV5WVXdCZwGbG1e5wIfB84BrtjHtq00XJLUn75G/ABJXgVc\nChwC3A28GzgYuBx4MXAvcFZVzbbTVElSG/oOfknSZBrqnbtJNia5PcmdSS4c5rFHIclxSa5vbnD7\nUZIPNMv3e6Nbks1J7mpujHvT6FrfviQHNVeBXdnMT2U/wIHfALlc+6P5XFuT3JzkC0kOnZZ+SPKZ\nJDuS3Dxv2QF/9iQnNf13Z5L/tKiDV9VQXvR+yfwYWEevPHQT8IphHX8UL2AtcGIz/XzgDuAV9M6B\nfLhZfiHwsWb6V4Eb6Z17+aWmvzLqz9Fif3wI+B/Alc38VPZD8xn/O/DuZnoFsHLa+qPJgruBQ5v5\nr9A7NzgV/QD8FnAicPO8ZQf82YEfAK9rpq8G3vxcxx7miP/1wF1VdW9V7Qa+TO+Gr2WrqrZX1U3N\n9MPAbcBx7P9Gt9OBL1fvhrh7gLvo9dvES3Ic8BZ654X2mLp+gL5ugFyu/fEQ8BhwZJIVwOHA/UxJ\nP1TVd4EHFyw+oM/eXD35gqr622a9z7OIG2eHGfwvArbNm7+vWTYVkvwSvd/u32f/N7ot7KP7WT59\n9KfABTz93o5p7Ac48Bsgl2V/VNWDwCeAn9L7THNVdR1T1g8LrD7Az/4ielm6x6Jy1adzDkGS5wNf\nAz7YjPyn6ka3JP8G2NH89fNsl/Mu636YxxsggSTr6ZX/1gH/gt7I//eZsn54Dkvy2YcZ/PcDx8+b\nP65Ztqw1f8J+DfjLqtpzX8OO5iY4mj/VHmiW30/vUtg9lksf/SZwepK7gS8Bb0jyl8D2KeuHPe4D\ntlXV3zXzX6f3i2Da/l+8Fvg/VfWzqnoC+CbwL5m+fpjvQD97X30yzOD/W+ClSdYlORQ4G7hyiMcf\nlc8Ct1bVJ+ctu5LejW7w9BvdrgTObq5seAnwUuBvhtXQpVJVH6mq46t3p/fZwPVV9U7gKqaoH/Zo\n/pTfluRlzaI9N0BO1f8Lehc7nJLksCSh1w+3Ml39EJ7+V/ABffamHDSX5PVNH76Lfdw4+wxDPou9\nkd4/9l3ARaM+qz6Ez/ubwBP0rmC6Ebih6YMXAtc1fXEtsGreNpvpnbG/DXjTqD/DEvTJv2LvVT3T\n3A+vojcYugn4Br2reqauP+id99kK3EzvZOYh09IPwBeB/wvsonee493A0Qf62YHXAD9qcvWTizm2\nN3BJ0pTx5K4kTRmDX5KmjMEvSVPG4JekKWPwS9KUMfglacoY/JI0ZQx+SZoy/x8exBjF2Z8chAAA\nAABJRU5ErkJggg==\n", | |
"text": [ | |
"<matplotlib.figure.Figure at 0x7f7da92f3050>" | |
] | |
} | |
], | |
"prompt_number": 4 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"initialize()\n", | |
"state = init_data(init=True, shuffle=True)\n", | |
"print \"F: \", full_fit(), \"%\" # n_iter = 1\n", | |
"print \"P: \", part_fit(), \"%\"\n", | |
"initialize()\n", | |
"bX, by = state\n", | |
"print \"P: \", part_fit(), \"%\"\n", | |
"print \"P: \", part_fit(), \"%\"" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"F: 35.0 %\n", | |
"P: 67.0 %\n", | |
"P: 66.0 %\n", | |
"P: 67.0 %\n" | |
] | |
} | |
], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# # Check if 2 Partial fits == 1 Fit\n", | |
"# est_pfit.partial_fit(X[:75], y[:75], classes = np.unique(y))\n", | |
"# est_pfit.partial_fit(X[75:], y[75:])\n", | |
"\n", | |
"c1 = est_pfit.coef_\n", | |
"\n", | |
"est_pfit.fit(X, y)\n", | |
"\n", | |
"c2 = est_pfit.coef_\n", | |
"\n", | |
"assert_array_almost_equal(c1, c2)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "AssertionError", | |
"evalue": "\nArrays are not almost equal to 6 decimals\n\n(mismatch 100.0%)\n x: array([[ -0.28176201, 2.08103825, -3.71397985, -7.04148776],\n [ 5.8815943 , -4.56869085, 9.96897204, -14.77459225],\n [ -4.6004667 , -1.31077272, 32.12912944, 40.1287937 ]])\n y: array([[ 2.55751194, 10.86712161, -8.38886032, -8.82729875],\n [ 2.93208873, 2.35362305, -12.97457253, -19.76706177],\n [ -8.42037216, -0.39783461, 17.44550535, 25.90058516]])", | |
"output_type": "pyerr", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-6-b807e23876c9>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0mc2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mest_pfit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcoef_\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m \u001b[0massert_array_almost_equal\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mc1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mc2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[1;32m/usr/lib/python2.7/dist-packages/numpy/testing/utils.pyc\u001b[0m in \u001b[0;36massert_array_almost_equal\u001b[1;34m(x, y, decimal, err_msg, verbose)\u001b[0m\n\u001b[0;32m 809\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0maround\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mz\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m<=\u001b[0m \u001b[1;36m10.0\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0mdecimal\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 810\u001b[0m assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,\n\u001b[1;32m--> 811\u001b[1;33m header=('Arrays are not almost equal to %d decimals' % decimal))\n\u001b[0m\u001b[0;32m 812\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 813\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0massert_array_less\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0merr_msg\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m/usr/lib/python2.7/dist-packages/numpy/testing/utils.pyc\u001b[0m in \u001b[0;36massert_array_compare\u001b[1;34m(comparison, x, y, err_msg, verbose, header)\u001b[0m\n\u001b[0;32m 642\u001b[0m names=('x', 'y'))\n\u001b[0;32m 643\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mcond\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 644\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 645\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 646\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mAssertionError\u001b[0m: \nArrays are not almost equal to 6 decimals\n\n(mismatch 100.0%)\n x: array([[ -0.28176201, 2.08103825, -3.71397985, -7.04148776],\n [ 5.8815943 , -4.56869085, 9.96897204, -14.77459225],\n [ -4.6004667 , -1.31077272, 32.12912944, 40.1287937 ]])\n y: array([[ 2.55751194, 10.86712161, -8.38886032, -8.82729875],\n [ 2.93208873, 2.35362305, -12.97457253, -19.76706177],\n [ -8.42037216, -0.39783461, 17.44550535, 25.90058516]])" | |
] | |
} | |
], | |
"prompt_number": 6 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"np.where(y != est_pfit.predict(X))[0].shape[0]" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment