Created
July 9, 2014 16:10
-
-
Save bmcfee/2ab57b0ca623d4429ed3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "metadata": { | |
| "name": "", | |
| "signature": "sha256:588089aeeb2e4764a0cabcaf5b6cd63fc58a162eca884c0a1346d0cd96e149c3" | |
| }, | |
| "nbformat": 3, | |
| "nbformat_minor": 0, | |
| "worksheets": [ | |
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Generator pooling and multiplexing\n", | |
| "\n", | |
| "Components:\n", | |
| "\n", | |
| " - GeneratorSeed\n", | |
| " \n", | |
| " A reusable generator factory: it stores a generator function (plus its arguments) or an iterable, and can construct a fresh stream from it as many times as needed.\n", | |
| " \n", | |
| " \n", | |
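| " - stochastic multiplexer\n", | |
| " \n", | |
| " generates samples from a pool of generators, keeping a few of them active at once:\n", | |
| " \n", | |
| " for i in 1, 2, ...\n", | |
| " \n", | |
| " select an active generator at random and yield its next sample; if it is exhausted, replace it with a new one drawn from the pool (while any remain)\n", | |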
| " \n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# What can we say about this?\n", | |
| "\n", | |
| "TODO" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
| "import collections\n", | |
| "import numpy as np\n", | |
| "import scipy" | |
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 1 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
class GeneratorSeed(object):
    '''A wrapper class for reusable generators.

    Stores a generator function (plus its arguments) or an iterable, and
    produces a stream from it each time ``generate()`` is called.

    :usage:
        >>> # make a generator
        >>> def my_generator(n):
        ...     for i in range(n):
        ...         yield i
        >>> GS = GeneratorSeed(my_generator, 5)
        >>> for i in GS.generate():
        ...     print(i)

        >>> # Or with a maximum number of items
        >>> for i in GS.generate(max_items=3):
        ...     print(i)

    :parameters:
      - generator : function or iterable
          Any generator function (or other callable returning an iterable),
          or any iterable python object.
          Note: if ``generator`` is a one-shot iterator (e.g. a generator
          object), every ``generate()`` call draws from that same iterator,
          so it is only reusable until the iterator is exhausted.

      - *args, **kwargs
          Additional positional arguments or keyword arguments to pass
          through to ``generator()`` when it is callable.
    '''

    def __init__(self, generator, *args, **kwargs):
        self.generator = generator
        self.args = args
        self.kwargs = kwargs

    def generate(self, max_items=None):
        '''Instantiate the generator.

        :parameters:
          - max_items : None or int >= 0
              Maximum number of items to yield. If ``None``, exhaust the generator.

        :yields:
          - items of the underlying stream, at most ``max_items`` of them

        :raises:
          - ValueError
              If ``generator`` is neither callable nor iterable.  Because this
              method is a generator, the error surfaces on the first ``next()``.
        '''
        if max_items is None:
            max_items = np.inf

        # If it's callable, create the stream.
        # If it's iterable, use it directly.
        # (iter() replaces collections.Iterable, which is gone in Python 3.10+.)
        if callable(self.generator):
            my_stream = self.generator(*self.args, **self.kwargs)
        else:
            try:
                my_stream = iter(self.generator)
            except TypeError:
                raise ValueError('generator is neither a generator nor iterable.')

        stream = iter(my_stream)
        count = 0
        # Check the budget *before* pulling the next item, so that no item is
        # consumed from the underlying stream and then silently discarded.
        while count < max_items:
            try:
                item = next(stream)
            except StopIteration:
                return
            yield item
            count += 1
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 2 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
def categorical_sample(weights):
    '''Draw one index from a categorical distribution.

    :parameters:
      - weights : np.array, shape=(n,)
          Category probabilities. Must be non-negative and sum to 1.0.

    :returns:
      - k : int in [0, n)
          Index of the sampled category.
    '''
    # A single multinomial trial yields a one-hot count vector;
    # the position of its lone 1 is the sampled category.
    one_hot = np.random.multinomial(1, weights)
    return np.argmax(one_hot)
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 3 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
def _instantiate_seed(seed_pool, seed_distribution, pool_weights, lam, with_replacement):
    '''Draw a seed, start a Poisson(lam)-length stream from it, return (stream, mixing weight).

    When sampling without replacement, the drawn seed's entry in
    ``seed_distribution`` is zeroed and the remainder renormalized (in place).
    '''
    seed_idx = categorical_sample(seed_distribution)
    stream = seed_pool[seed_idx].generate(max_items=np.random.poisson(lam=lam))

    if not with_replacement:
        seed_distribution[seed_idx] = 0.0
        total = np.sum(seed_distribution)
        if total > 0:
            # in-place division: the caller's array is updated
            seed_distribution /= total

    return stream, pool_weights[seed_idx]


def generator_mux(seed_pool, n_samples, k, lam=256.0, pool_weights=None, with_replacement=True):
    '''Stochastically multiplex samples from a pool of generator seeds.

    Up to ``k`` streams are active at once.  Each active stream comes from a
    seed drawn (uniformly among the still-eligible seeds) from ``seed_pool``
    and is truncated to a Poisson(``lam``)-distributed number of items.
    At every step one active stream is picked with probability proportional
    to its seed's entry in ``pool_weights``, and its next item is yielded.
    An exhausted stream is replaced by a freshly drawn seed while eligible
    seeds remain; otherwise it is dropped from the mix.

    :parameters:
      - seed_pool : sequence of GeneratorSeed
          Objects with a ``generate(max_items=...)`` method.

      - n_samples : int, float or None
          Maximum number of samples to yield.  ``None`` means no limit.

      - k : int
          Number of simultaneously active streams.

      - lam : float > 0
          Poisson rate controlling the length of each instantiated stream.

      - pool_weights : None or array-like, shape=(len(seed_pool),)
          Strictly positive mixing weights for the seeds; normalized
          internally.  The caller's object is never modified.
          ``None`` means uniform.

      - with_replacement : bool
          If ``False``, each seed is instantiated at most once.
          Note: with replacement, if every seed yields nothing, exhausted
          streams are replaced forever and no sample is ever produced.

    :yields:
      - items drawn from the active streams

    :raises:
      - ValueError if ``seed_pool`` is empty (on the first ``next()``)
      - AssertionError if ``pool_weights`` has the wrong length or a
        non-positive entry
    '''
    n_seeds = len(seed_pool)
    if n_seeds == 0:
        raise ValueError('Cannot mux an empty seed pool')

    if n_samples is None:
        n_samples = np.inf

    # Distribution used to *select* seeds.  It starts uniform; when sampling
    # without replacement, used seeds are zeroed out of it.
    seed_distribution = np.ones(n_seeds) / n_seeds

    # Distribution used to *mix* active streams.  np.array() copies, so the
    # caller's weights are not mutated, and lists / int arrays work too.
    if pool_weights is None:
        pool_weights = seed_distribution.copy()
    else:
        pool_weights = np.array(pool_weights, dtype=float)

    assert len(pool_weights) == n_seeds
    assert (pool_weights > 0.0).all()
    pool_weights /= np.sum(pool_weights)

    # Instantiate the pool of active streams.  If fewer than k seeds are
    # eligible, the trailing stream_weights stay 0 and are never selected.
    streams = []
    stream_weights = np.zeros(k)

    for i in range(k):
        if not (seed_distribution > 0).any():
            break
        stream, stream_weights[i] = _instantiate_seed(seed_pool, seed_distribution,
                                                      pool_weights, lam, with_replacement)
        streams.append(stream)

    Z = np.sum(stream_weights)

    # Main sampling loop
    n = 0
    while n < n_samples and Z > 0.0:
        # Pick an active stream
        idx = categorical_sample(stream_weights / Z)

        try:
            sample = next(streams[idx])
        except StopIteration:
            # This stream is exhausted: replace it from the *eligible* seeds
            # (seed_distribution respects with_replacement), or retire it.
            if (seed_distribution > 0).any():
                streams[idx], stream_weights[idx] = _instantiate_seed(
                    seed_pool, seed_distribution, pool_weights, lam, with_replacement)
            else:
                stream_weights[idx] = 0.0
            Z = np.sum(stream_weights)
            continue

        yield sample
        n += 1
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 4 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
def stream_fit(estimator, data_sequence, batch_size=100, max_steps=None, **kwargs):
    '''Fit a model to a generator stream.

    Samples are buffered into mini-batches of ``batch_size`` and passed to
    ``estimator.partial_fit()``.  A final, possibly smaller, batch is flushed
    at the end.

    :parameters:
      - estimator : sklearn.base.BaseEstimator
          The model object. Must implement ``partial_fit()``.
          Classifiers and regressors are treated as supervised.

      - data_sequence : generator
          A generator that yields samples.  For supervised estimators each
          sample must be an ``(x, y)`` tuple; otherwise each sample is ``x``.

      - batch_size : int
          Maximum number of samples to buffer before updating the model

      - max_steps : int or None
          If ``None``, run until the stream is exhausted.
          Otherwise, process at most ``max_steps`` examples.

      - **kwargs
          Passed through to every ``partial_fit()`` call (e.g. ``classes=...``).
    '''
    # Imported here so this cell does not depend on later cells having
    # imported sklearn / scipy.sparse.
    import itertools
    import scipy.sparse
    import sklearn.base

    # Is this a supervised learner?
    supervised = isinstance(estimator, (sklearn.base.ClassifierMixin,
                                        sklearn.base.RegressorMixin))

    # Does the learner support partial fit?
    assert hasattr(estimator, 'partial_fit')

    def _matrixify(data):
        '''Stack a batch of samples into one sparse (CSR) or dense matrix.'''
        if scipy.sparse.issparse(data[0]):
            # assumes each sparse sample is a single 1 x d row
            return scipy.sparse.vstack(data, format='csr', dtype=data[0].dtype)
        return np.asarray(data)

    def _run(batch):
        '''Wrapper function to partial_fit()'''
        if supervised:
            # [(x, y), ...] -> X matrix, y vector
            args = [_matrixify(column) for column in zip(*batch)]
        else:
            args = [_matrixify(batch)]
        estimator.partial_fit(*args, **kwargs)

    # islice caps the stream without pulling (and discarding) an extra item.
    if max_steps is not None:
        data_sequence = itertools.islice(data_sequence, max(0, int(max_steps)))

    buf = []
    for sample in data_sequence:
        buf.append(sample)

        # Buffer is full, do an update
        if len(buf) == batch_size:
            _run(buf)
            buf = []

    # Update on whatever's left over
    if buf:
        _run(buf)
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 5 | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Demonstration" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
| "import itertools" | |
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 6 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
# A seed wrapping a plain iterable (0..49); generate() iterates it directly.
GS = GeneratorSeed(generator=range(50))
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 7 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
# max_items is only an upper bound: the underlying range has just 50 items,
# so all of them are printed.
# print() with a single argument is valid on both Python 2 and Python 3.
print(' '.join(str(q) for q in GS.generate(max_items=80)))
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49\n" | |
| ] | |
| } | |
| ], | |
| "prompt_number": 8 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
def my_generator(*args):
    '''Yield the tuple of positional arguments ``args``, forever.

    Every item is the same tuple, so in the demos below each sample
    identifies the seed it came from.
    '''
    item = args
    while 1:
        yield item
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 9 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
# Here's a simple, one-layer generator stream: 30 seeds, each endlessly
# yielding its own index as a 1-tuple, muxed through k=5 active streams
# whose lengths are drawn from Poisson(10).
one_layer = generator_mux([GeneratorSeed(my_generator, x) for x in range(30)],
                          500,
                          5,
                          lam=10)
# print() with a single argument is valid on both Python 2 and Python 3.
print(' '.join(str(q) for q in one_layer))
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "(16,) (16,) (2,) (16,) (28,) (28,) (14,) (28,) (2,) (16,) (27,) (16,) (27,) (2,) (14,) (27,) (27,) (27,) (16,) (2,) (27,) (28,) (27,) (16,) (14,) (14,) (16,) (14,) (2,) (16,) (2,) (28,) (2,) (2,) (16,) (14,) (27,) (28,) (27,) (28,) (14,) (16,) (3,) (28,) (14,) (27,) (16,) (28,) (3,) (16,) (14,) (10,) (27,) (10,) (27,) (14,) (10,) (10,) (27,) (14,) (1,) (3,) (1,) (3,) (3,) (10,) (26,) (1,) (26,) (26,) (26,) (26,) (1,) (1,) (3,) (26,) (1,) (27,) (26,) (26,) (27,) (10,) (27,) (6,) (10,) (6,) (27,) (10,) (26,) (26,) (26,) (6,) (6,) (6,) (3,) (10,) (6,) (3,) (26,) (12,) (6,) (6,) (6,) (12,) (12,) (12,) (12,) (3,) (3,) (23,) (12,) (12,) (12,) (3,) (13,) (13,) (12,) (12,) (23,) (13,) (12,) (23,) (23,) (13,) (24,) (23,) (23,) (12,) (23,) (24,) (13,) (24,) (23,) (23,) (13,) (23,) (13,) (17,) (24,) (23,) (13,) (17,) (24,) (12,) (13,) (17,) (12,) (13,) (17,) (13,) (13,) (15,) (3,) (15,) (12,) (17,) (17,) (23,) (3,) (15,) (23,) (17,) (15,) (12,) (23,) (15,) (15,) (12,) (3,) (3,) (3,) (15,) (3,) (15,) (3,) (15,) (17,) (15,) (17,) (17,) (15,) (15,) (15,) (6,) (15,) (15,) (15,) (15,) (12,) (12,) (15,) (15,) (15,) (15,) (6,) (15,) (29,) (29,) (22,) (15,) (15,) (15,) (29,) (6,) (6,) (29,) (29,) (15,) (15,) (15,) (15,) (22,) (15,) (6,) (29,) (15,) (29,) (15,) (22,) (22,) (15,) (22,) (6,) (29,) (15,) (15,) (15,) (29,) (6,) (15,) (22,) (15,) (15,) (29,) (29,) (15,) (15,) (22,) (28,) (16,) (15,) (29,) (15,) (28,) (16,) (15,) (28,) (16,) (6,) (28,) (15,) (28,) (28,) (11,) (6,) (16,) (28,) (16,) (28,) (11,) (6,) (15,) (16,) (19,) (11,) (3,) (3,) (28,) (28,) (3,) (19,) (19,) (8,) (8,) (3,) (28,) (11,) (8,) (8,) (3,) (19,) (3,) (8,) (11,) (3,) (8,) (11,) (11,) (19,) (3,) (19,) (11,) (8,) (8,) (19,) (8,) (19,) (3,) (26,) (26,) (19,) (3,) (28,) (2,) (2,) (3,) (11,) (19,) (26,) (26,) (19,) (26,) (2,) (19,) (11,) (3,) (26,) (26,) (11,) (11,) (26,) (11,) (26,) (20,) (2,) (26,) (20,) (26,) (26,) (2,) (11,) (26,) (20,) (20,) (20,) (2,) (29,) (11,) (11,) (11,) (11,) (20,) (11,) (11,) (11,) 
(29,) (11,) (11,) (13,) (11,) (29,) (13,) (11,) (13,) (20,) (11,) (11,) (29,) (13,) (13,) (13,) (20,) (20,) (11,) (13,) (17,) (13,) (13,) (29,) (13,) (29,) (13,) (17,) (29,) (17,) (29,) (29,) (29,) (29,) (29,) (13,) (20,) (29,) (11,) (21,) (21,) (11,) (29,) (17,) (17,) (21,) (29,) (5,) (5,) (21,) (5,) (5,) (29,) (21,) (5,) (11,) (11,) (13,) (21,) (5,) (5,) (11,) (21,) (13,) (13,) (5,) (29,) (11,) (11,) (13,) (10,) (29,) (10,) (13,) (10,) (21,) (11,) (11,) (13,) (13,) (10,) (11,) (21,) (29,) (21,) (21,) (21,) (11,) (10,) (11,) (2,) (11,) (11,) (11,) (2,) (2,) (10,) (2,) (21,) (21,) (21,) (10,) (21,) (2,) (2,) (21,) (21,) (10,) (2,) (11,) (21,) (2,) (0,) (0,) (29,) (11,) (11,) (0,) (0,) (29,) (29,) (29,) (0,) (12,) (0,) (12,) (0,) (12,) (12,) (29,) (17,) (8,) (12,) (0,) (17,) (0,) (0,) (17,) (8,) (0,) (29,) (0,) (8,) (29,) (29,) (29,) (12,) (12,) (0,) (29,) (8,) (17,) (29,) (8,) (29,)\n" | |
| ] | |
| } | |
| ], | |
| "prompt_number": 10 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
# Let's do a two-layer stream: each top-level seed is itself a mux over ten
# `my_generator` seeds tagged with a letter, so every item shows which inner
# mux ('A'..'E') and which inner seed (0..9) produced it.
muxen = [GeneratorSeed(generator_mux,
                       [GeneratorSeed(my_generator, x, y) for y in range(10)],
                       100,
                       10,
                       lam=20)
         for x in 'ABCDE']

# with_replacement=False requests that each inner mux be drawn at most once.
two_layer = generator_mux(muxen, 200, 2, lam=20, with_replacement=False)
# print() with a single argument is valid on both Python 2 and Python 3.
print(' '.join(str(q) for q in two_layer))
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "('A', 7) ('A', 5) ('E', 7) ('E', 1) ('E', 8) ('A', 3) ('A', 3) ('A', 5) ('E', 2) ('E', 1) ('E', 4) ('E', 2) ('A', 7) ('A', 2) ('A', 8) ('A', 4) ('A', 8) ('A', 3) ('E', 4) ('A', 2) ('E', 2) ('E', 4) ('A', 8) ('E', 0) ('E', 2) ('A', 7) ('E', 8) ('A', 2) ('A', 7) ('E', 8) ('E', 4) ('A', 2) ('A', 3) ('A', 8) ('E', 1) ('E', 3) ('E', 4) ('A', 5) ('E', 3) ('E', 1) ('A', 4) ('E', 1) ('A', 5) ('E', 4) ('E', 4) ('C', 8) ('C', 6) ('A', 4) ('A', 8) ('A', 3) ('A', 3) ('C', 5) ('C', 2) ('C', 6) ('C', 2) ('C', 6) ('A', 2) ('C', 6) ('A', 4) ('C', 7) ('A', 7) ('A', 0) ('A', 1) ('A', 7) ('C', 5) ('A', 6) ('A', 9) ('C', 0) ('C', 2) ('C', 0) ('C', 8) ('A', 6) ('C', 6) ('A', 4) ('C', 8) ('C', 0) ('C', 8) ('C', 8) ('C', 2) ('A', 4) ('A', 1) ('C', 5) ('A', 7) ('A', 0) ('C', 2) ('A', 0) ('C', 5) ('A', 0) ('A', 6) ('A', 1) ('A', 1) ('A', 1) ('A', 3) ('A', 1) ('A', 1) ('A', 6) ('A', 1) ('A', 0) ('A', 7) ('A', 4) ('E', 8) ('E', 9) ('A', 7) ('A', 5) ('A', 8) ('E', 4) ('E', 5) ('E', 9) ('A', 5) ('A', 6) ('A', 5) ('E', 7) ('E', 5) ('E', 7) ('A', 7) ('A', 1) ('E', 0) ('A', 1) ('A', 4) ('A', 1) ('E', 7) ('E', 7) ('E', 5) ('A', 1) ('A', 8) ('A', 6) ('D', 6) ('D', 0) ('A', 7) ('A', 1) ('A', 1) ('A', 6) ('A', 3) ('A', 1) ('D', 0) ('C', 8) ('C', 3) ('D', 6) ('D', 4) ('C', 8) ('C', 4) ('D', 7) ('C', 2) ('C', 6) ('D', 7) ('C', 3) ('D', 7) ('C', 8) ('C', 4) ('C', 5) ('D', 4) ('D', 4) ('C', 8) ('D', 5) ('C', 2) ('D', 4) ('C', 2) ('C', 8) ('C', 3) ('D', 0) ('C', 4) ('C', 6) ('C', 9) ('C', 8) ('C', 2) ('C', 8) ('D', 5) ('D', 0) ('C', 9) ('C', 2) ('D', 5) ('C', 5) ('C', 8) ('B', 9) ('D', 0) ('D', 7) ('B', 3) ('D', 4) ('D', 0) ('D', 4) ('B', 8) ('B', 9) ('B', 8) ('B', 2) ('D', 4) ('D', 4) ('D', 7) ('B', 0) ('D', 4) ('B', 8) ('B', 9) ('B', 9) ('B', 7) ('B', 2) ('B', 9) ('B', 9) ('B', 7) ('B', 4) ('B', 2) ('B', 9)\n" | |
| ] | |
| } | |
| ], | |
| "prompt_number": 18 | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Out-of-core learning with scikit-learn (learn_ooc example)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
| "import sklearn\n", | |
| "import sklearn.datasets\n", | |
| "import sklearn.linear_model\n", | |
| "import sklearn.cross_validation\n", | |
| "import sklearn.metrics\n", | |
| "import sklearn.grid_search\n", | |
| "\n", | |
| "import learn_ooc" | |
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 19 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
def data_generator(X, Y, target_class, scale=1e-1):
    '''Endlessly yield noisy copies of the examples belonging to one class.

    :parameters:
      - X : np.ndarray, shape=(n, d)
          Feature matrix.

      - Y : np.ndarray, shape=(n,)
          Labels aligned with the rows of ``X``.

      - target_class
          Only rows with ``Y == target_class`` are sampled.

      - scale : float >= 0
          Standard deviation of the isotropic Gaussian noise added to each sample.

    :yields:
      - (x, y) : (np.ndarray shape=(d,), label)
          An example of ``target_class`` chosen uniformly at random (with
          replacement), plus noise, and its label.

    :raises:
      - ValueError if no entry of ``Y`` equals ``target_class``
        (raised on the first ``next()``, since this is a generator)
    '''
    idx = np.flatnonzero(Y == target_class)
    if len(idx) == 0:
        raise ValueError('No examples of class %r to sample from' % (target_class,))

    X_class = X[idx]
    Y_class = Y[idx]

    n, d = X_class.shape

    while True:
        i = np.random.randint(0, n)
        noise = scale * np.random.randn(d)
        yield X_class[i] + noise, Y_class[i]
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 20 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
# Load the iris dataset: X holds the feature matrix, Y the class labels.
data = sklearn.datasets.load_iris()
X, Y = data.data, data.target
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [], | |
| "prompt_number": 21 | |
| }, | |
| { | |
| "cell_type": "code", | |
| "collapsed": false, | |
| "input": [ | |
# For each train/test split, train a linear classifier purely from a
# multiplexed stream of noisy, class-conditional samples.

# data_generator and generator_mux draw from numpy's global RNG.
np.random.seed(0)

try:
    # scikit-learn >= 0.18
    from sklearn.model_selection import ShuffleSplit
    splits = ShuffleSplit(n_splits=5, test_size=0.2, random_state=0).split(X)
except ImportError:
    # older scikit-learn (the API this notebook was originally written against)
    from sklearn.cross_validation import ShuffleSplit
    splits = ShuffleSplit(len(X), n_iter=5, test_size=0.2, random_state=0)

for train, test in splits:

    # Make the streams: one seed per class
    seeds = [GeneratorSeed(data_generator, X[train], Y[train], z) for z in range(3)]

    # Make the mux: k=3 without replacement activates each class stream once.
    # n_samples=5e5 is only a cap; the three Poisson(1e4)-length streams
    # run out first.
    mux_stream = generator_mux(seeds, 5e5, k=3, lam=1e4, with_replacement=False)

    # Make a model
    CF = sklearn.linear_model.SGDClassifier(verbose=0, penalty='l1', n_jobs=3)

    # partial_fit needs the full set of classes up front
    stream_fit(CF, mux_stream, batch_size=512, classes=np.arange(3))

    # print() with a single argument is valid on both Python 2 and Python 3.
    print('Accuracy: %.3f' % sklearn.metrics.accuracy_score(Y[test], CF.predict(X[test])))
| ], | |
| "language": "python", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "Accuracy: 0.800\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "Accuracy: 0.933\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "Accuracy: 0.900\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "Accuracy: 1.000\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "stream", | |
| "stream": "stdout", | |
| "text": [ | |
| "Accuracy: 0.900\n" | |
| ] | |
| } | |
| ], | |
| "prompt_number": 22 | |
| } | |
| ], | |
| "metadata": {} | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment