
@dwf
Created October 12, 2016 04:37
Example of interleaving data streams with Fuel and conditional processing with Blocks.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: Dispatching multiple algorithms on an interleaved data stream"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Objective:** we'd like to have a `DataStream` that alternately returns batches from two (or more) different underlying `DataStream`s. The catch is that we'd also like a way to *identify* which stream a given batch came from, so that we can *process batches differently* during training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we'll create an interleaving DataStream that lets you alternate batches from several different streams. We won't bother implementing `get_data` because we only care about epoch iterators. We'll use the `interleave` function from `picklable_itertools.extras` (which is reimplemented from the version in [Toolz](http://toolz.readthedocs.org/))."
]
},
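{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an illustrative aside, not part of the pipeline itself), `interleave` yields one element from each iterable in turn, continuing with whatever iterables remain once the shorter ones are exhausted:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from picklable_itertools.extras import interleave\n",
"\n",
"# Round-robin over the iterables; should yield [1, 4, 2, 5, 3].\n",
"print(list(interleave([[1, 2, 3], [4, 5]])))"
]
},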
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from picklable_itertools.extras import interleave\n",
"from fuel.streams import AbstractDataStream\n",
"\n",
"class Interleave(AbstractDataStream):\n",
"    def __init__(self, streams, **kwargs):\n",
"        self.streams = streams\n",
"        super(Interleave, self).__init__(**kwargs)\n",
"\n",
"    def get_data(self):\n",
"        # Nothing sensible can be easily done here\n",
"        raise NotImplementedError\n",
"\n",
"    def next_epoch(self):\n",
"        for stream in self.streams:\n",
"            stream.next_epoch()\n",
"\n",
"    def close(self):\n",
"        for stream in self.streams:\n",
"            stream.close()\n",
"\n",
"    def reset(self):\n",
"        for stream in self.streams:\n",
"            stream.reset()\n",
"\n",
"    def get_epoch_iterator(self, as_dict=False):\n",
"        return interleave([s.get_epoch_iterator(as_dict=as_dict) for s in self.streams])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'd like our algorithm to take different actions based on which original stream a batch came from. Rather than complicate the `Interleave` stream further, we can create a transformer that adds a source containing an arbitrary token to every batch. We'll use this token to identify which original stream the batch came from."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from fuel.transformers import Transformer\n",
"\n",
"\n",
"class IdentifierSource(Transformer):\n",
"    \"\"\"Add a source to the DataStream that is just a repeating stream identifier.\"\"\"\n",
"    def __init__(self, stream, new_source_name, id_value, **kwargs):\n",
"        self.id_source_name = new_source_name\n",
"        self.id_value = id_value\n",
"        super(IdentifierSource, self).__init__(stream, **kwargs)\n",
"\n",
"    @property\n",
"    def produces_examples(self):\n",
"        return self.data_stream.produces_examples\n",
"\n",
"    @property\n",
"    def sources(self):\n",
"        # tuple() guards against wrapped streams whose sources are a list,\n",
"        # which would otherwise make this concatenation raise a TypeError.\n",
"        return tuple(super(IdentifierSource, self).sources) + (self.id_source_name,)\n",
"\n",
"    def transform_batch(self, batch):\n",
"        return batch + (self.id_value,)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we'll add a simple `TrainingAlgorithm` that dispatches each batch to one of several delegate `TrainingAlgorithm`s based on such an identifier source. It takes the name of the identifier source to check, and a dictionary mapping identifier tokens that might appear in that source to `TrainingAlgorithm`s."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from blocks.algorithms import TrainingAlgorithm\n",
"\n",
"\n",
"class SourceRoutedTrainingAlgorithm(TrainingAlgorithm):\n",
"    def __init__(self, identifier_source, algorithm_dict, **kwargs):\n",
"        self.identifier_source = identifier_source\n",
"        self.algorithm_dict = algorithm_dict\n",
"        super(SourceRoutedTrainingAlgorithm, self).__init__(**kwargs)\n",
"\n",
"    def initialize(self):\n",
"        for key in sorted(self.algorithm_dict):  # sorted for consistent execution order\n",
"            self.algorithm_dict[key].initialize()\n",
"\n",
"    def process_batch(self, batch):\n",
"        # .pop() because we don't want/need the delegate algorithm to get the identifier token.\n",
"        identifier = batch.pop(self.identifier_source)\n",
"        return self.algorithm_dict[identifier].process_batch(batch)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's assemble a stream that alternates batches of MNIST and CIFAR10."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from fuel.datasets import MNIST, CIFAR10\n",
"from fuel.streams import DataStream\n",
"from fuel.schemes import SequentialScheme\n",
"\n",
"\n",
"def make_stream(dataset, batch_size):\n",
"    scheme = SequentialScheme(batch_size=batch_size, examples=dataset.num_examples)\n",
"    return DataStream.default_stream(dataset, iteration_scheme=scheme)\n",
"\n",
"mnist_stream = make_stream(MNIST(('train',)), 10)\n",
"cifar_stream = make_stream(CIFAR10(('train',)), 20)\n",
"\n",
"# The 'which_stream' source will contain either the string 'mnist' or 'cifar'.\n",
"interleaved = Interleave([IdentifierSource(mnist_stream, 'which_stream', 'mnist'),\n",
"                          IdentifierSource(cifar_stream, 'which_stream', 'cifar')])"
]
},
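{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check (an illustrative sketch, not part of the original pipeline), we can pull the first couple of batches out of `interleaved` and confirm that the extra `which_stream` source shows up alongside `features` and `targets`, with the token alternating between the two streams:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"iterator = interleaved.get_epoch_iterator(as_dict=True)\n",
"for _ in range(2):\n",
"    batch = next(iterator)\n",
"    # Expect ['features', 'targets', 'which_stream'], with the token\n",
"    # alternating between 'mnist' and 'cifar'.\n",
"    print(sorted(batch.keys()), batch['which_stream'])"
]
},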
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Testing it out: we'll create a dummy algorithm class that just prints some diagnostics. Ordinarily you'd use `GradientDescent` or something like that."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class DummyAlgorithm(TrainingAlgorithm):\n",
"    def __init__(self, number, **kwargs):\n",
"        self.number = number\n",
"        super(DummyAlgorithm, self).__init__(**kwargs)\n",
"\n",
"    def initialize(self):\n",
"        print('Algorithm %d initialized' % self.number)\n",
"\n",
"    def process_batch(self, batch):\n",
"        print('Training algorithm %d reporting' % self.number)\n",
"        print('Got minibatch with sources:', list(batch.keys()))\n",
"        # Print the shape so we can visually verify whether it's an MNIST\n",
"        # or CIFAR batch. Recall that MNIST is 28x28 and CIFAR10 is 32x32,\n",
"        # and the base streams have different batch sizes (10 and 20).\n",
"        print('features.shape =', batch['features'].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll create a `SourceRoutedTrainingAlgorithm` that routes to two different `DummyAlgorithm`s depending on whether the batch is an MNIST batch or a CIFAR batch."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Algorithm 2 initialized\n",
"Algorithm 1 initialized\n",
"Training algorithm 1 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (10, 1, 28, 28)\n",
"Training algorithm 2 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (20, 3, 32, 32)\n",
"Training algorithm 1 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (10, 1, 28, 28)\n",
"Training algorithm 2 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (20, 3, 32, 32)\n",
"Training algorithm 1 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (10, 1, 28, 28)\n",
"Training algorithm 2 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (20, 3, 32, 32)\n",
"Training algorithm 1 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (10, 1, 28, 28)\n",
"Training algorithm 2 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (20, 3, 32, 32)\n",
"Training algorithm 1 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (10, 1, 28, 28)\n",
"Training algorithm 2 reporting\n",
"Got minibatch with sources: ['targets', 'features']\n",
"features.shape = (20, 3, 32, 32)\n"
]
}
],
"source": [
"from blocks.main_loop import MainLoop\n",
"from blocks.extensions import FinishAfter\n",
"\n",
"algorithm = SourceRoutedTrainingAlgorithm('which_stream',\n",
"                                          {k: DummyAlgorithm(n + 1)\n",
"                                           for n, k in enumerate(['mnist', 'cifar'])})\n",
"\n",
"MainLoop(data_stream=interleaved, algorithm=algorithm,\n",
"         extensions=[FinishAfter(after_n_batches=10)]).run()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.5"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@apeterswu

Great post. This will be a great help.

@apeterswu commented Oct 24, 2016

I ran into a problem while using this:

return super(IdentifierSource, self).sources + (self.id_source_name,)
TypeError: can only concatenate list (not "tuple") to list.

Do I need to change it like this:

return super(IdentifierSource, self).sources + [self.id_source_name, ]
