Minimal examples of using the Blocks MainLoop and checkpointing machinery.
# lib.py
# It's important not to define any classes you want serialized in
# the script you're running, as pickle doesn't like that (if you pass
# save_main_loop=False to Checkpoint it's fine, though).
from theano import tensor
import numpy
import theano
from picklable_itertools import imap, izip, repeat
# Your algorithm object just needs two required methods: initialize()
# and process_batch().
#
# Note that in many, many cases the GradientDescent object defined
# in blocks.algorithms (and the associated StepRule objects, like
# Momentum and Adam) is sufficient. Just pass the cost, parameters
# (or gradients, if you prefer) and step rule.
#
class MyAlgorithm(object):
    def __init__(self, cost, X, y, parameters, learning_rate=0.01):
        gradients = tensor.grad(cost, wrt=parameters)
        updates = [(p, p - learning_rate * g)
                   for p, g in zip(parameters, gradients)]
        # Returns the cost and then updates the parameters.
        self.func = theano.function([X, y], cost, updates=updates)

    def initialize(self):
        # Do whatever you want here; Blocks will call it before
        # processing any batches. For example, you might decide to
        # compile the function here instead of in the constructor.
        pass

    def process_batch(self, batch):
        # batch is a dictionary mapping names to data (e.g. numpy arrays).
        print('cost: ', self.func(batch['X'], batch['y']))
# A DataStream just needs one method: get_epoch_iterator. It should
# return an iterator object that yields one batch per iteration.
#
# NOTE: Checkpointing requires these iterators to be serializable by
# Python's pickle module ("picklable"). Many common iterators are not
# picklable on Python 2 (but are on Python 3). So Python 3 will make your
# life easier in this regard; otherwise you may find the package
# picklable-itertools <http://github.com/mila-udem/picklable-itertools>
# helpful.
#
# If you use Fuel to make your stream then this is all taken care of
# for you for lots of common datasets.
class MyDataStream(object):
    def __init__(self, dim, offset=5, batch_size=20, batches=50):
        self.dim = dim
        self.offset = offset
        self.batches = batches
        self.batch_size = batch_size

    def get_epoch_iterator(self, as_dict=False):
        # The goal of this "dataset" is to predict the mean of the vector
        # plus a constant. We'll give back an iterator of `batches` batches
        # of `batch_size` examples, each with `dim` features.
        X = numpy.random.normal(size=(self.batches, self.batch_size, self.dim))
        y = X.mean(axis=2) + self.offset
        if as_dict:
            # Return an iterator of dictionaries of the form
            # {'X': features, 'y': targets}
            # This is the equivalent of the more idiomatic
            #
            #   iter([{'X': x, 'y': y} for x, y in zip(X, y)])
            #
            # but that won't be picklable on Python 2, at least.
            return imap(dict, izip(izip(repeat('X'), X), izip(repeat('y'), y)))
        else:
            return izip(X, y)
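
For reference, here is a minimal sketch, not part of the original gist, of what MainLoop does with these two objects: call initialize() once, then feed each batch dictionary from an epoch iterator to process_batch(). The model setup mirrors the scripts below; the file name and dim=5 are made up for illustration.

# manual_loop.py (hypothetical driver script, for illustration only)
from lib import MyAlgorithm, MyDataStream
import numpy
import theano
from theano import tensor

dim = 5
X = tensor.matrix('X')
y = tensor.vector('y')
W = theano.shared(numpy.zeros((dim, 1)), name='W')
b = theano.shared(numpy.zeros(1,), name='b')
cost = tensor.sqr(tensor.dot(X, W) + b - y.dimshuffle(0, 'x')).mean()

algorithm = MyAlgorithm(cost, X, y, [W, b])
algorithm.initialize()
stream = MyDataStream(dim)
# One epoch: process every batch the stream hands back.
for batch in stream.get_epoch_iterator(as_dict=True):
    algorithm.process_batch(batch)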
# minimal_blocks.py
from lib import MyAlgorithm, MyDataStream
from blocks.main_loop import MainLoop
from blocks.extensions.saveload import Checkpoint
from blocks.extensions import Printing
import numpy
import theano
from theano import tensor

if __name__ == "__main__":
    dim = 25
    # Just some Theano stuff. Nothing Blocks-y.
    X = tensor.matrix('X')
    y = tensor.vector('y')
    W = theano.shared(numpy.zeros((dim, 1)), name='W')
    b = theano.shared(numpy.zeros(1,), name='b')
    predicted = tensor.dot(X, W) + b
    cost = tensor.sqr(predicted - y.dimshuffle(0, 'x')).mean()
    # Build a checkpointer. You specify how often with keywords,
    # e.g. after_epoch, every_n_batches, every_n_epochs, etc.
    #
    # This checkpoints the whole MainLoop by default, or just
    # the parameters if you pass in a parameters list and also
    # pass save_main_loop=False.
    # Passing the parameters as a keyword argument lets you retrieve
    # the parameters separately from the checkpoint with
    #
    # >>> from blocks.serialization import load_parameters
    # >>> with open('mycheckpoint.tar', 'rb') as f:
    # ...     params = load_parameters(f)
    checkpointer = Checkpoint('mycheckpoint.tar', every_n_epochs=500,
                              save_main_loop=False, parameters=[W, b])
    # This just prints some handy stuff after every epoch by default.
    # Configurable with the same frequency arguments as the checkpointer.
    printer = Printing()
    main_loop = MainLoop(algorithm=MyAlgorithm(cost, X, y, [W, b]),
                         data_stream=MyDataStream(dim),
                         extensions=[checkpointer, printer])
    main_loop.run()
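
If you later want to resume from the saved parameters, something like the sketch below should work. It is not part of the gist: load_parameters returns a plain dict of arrays, but the exact key strings are derived from the variables' names by Blocks, so inspect params.keys() before indexing. It assumes the W and b shared variables from the script above.

from blocks.serialization import load_parameters

with open('mycheckpoint.tar', 'rb') as f:
    params = load_parameters(f)
print(list(params.keys()))  # check the actual key format first
# 'W' and 'b' below are hypothetical keys; use whatever keys() reported.
W.set_value(params['W'])
b.set_value(params['b'])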
# Same as minimal_blocks.py but demonstrates the GradientDescent object.
from lib import MyDataStream
from blocks.algorithms import GradientDescent, Scale
from blocks.main_loop import MainLoop
from blocks.extensions.saveload import Checkpoint
from blocks.extensions import Printing
from blocks.extensions.monitoring import TrainingDataMonitoring
import numpy
import theano
from theano import tensor

if __name__ == "__main__":
    dim = 25
    # Just some Theano stuff. Nothing Blocks-y.
    X = tensor.matrix('X')
    y = tensor.vector('y')
    W = theano.shared(numpy.zeros((dim, 1)), name='W')
    b = theano.shared(numpy.zeros(1,), name='b')
    predicted = tensor.dot(X, W) + b
    cost = tensor.sqr(predicted - y.dimshuffle(0, 'x')).mean()
    # Build a checkpointer. You specify how often with keywords,
    # e.g. after_epoch, every_n_batches, every_n_epochs, etc.
    #
    # This checkpoints the whole MainLoop by default, or just
    # the parameters if you pass in a parameters list and also
    # pass save_main_loop=False.
    # Passing the parameters as a keyword argument lets you retrieve
    # the parameters separately from the checkpoint with
    #
    # >>> from blocks.serialization import load_parameters
    # >>> with open('mycheckpoint.tar', 'rb') as f:
    # ...     params = load_parameters(f)
    checkpointer = Checkpoint('mycheckpoint.tar', every_n_epochs=500,
                              save_main_loop=False, parameters=[W, b])
    # This just prints some handy stuff after every epoch by default.
    # Configurable with the same frequency arguments as the checkpointer.
    printer = Printing()
    # Build a GradientDescent to do simple SGD. The Scale(0.01) says
    # we just want to scale the gradients by a fixed learning rate.
    # See also Adam(), AdaDelta(), Momentum(), etc.
    #
    # If you want to do something weird for the gradients, just pass in
    # gradients=my_gradients instead of (or in addition to) the parameters.
    # By default the algorithm prints nothing, but you can get it to
    # (efficiently) keep tabs on quantities like the cost by adding a
    # monitoring extension.
    # By default this just tells you the last batch's value; you can look
    # in blocks-examples for how to aggregate over the training set
    # as you process it (blocks.monitoring's aggregation.mean(cost),
    # or something like that).
    #
    # See also DataStreamMonitoring for monitoring on a validation set.
    # Note that you'll want to give the cost variable an informative
    # name, as that's what gets used in the MainLoop's log.
    # (It's important to put the monitoring extension before the Printing
    # extension in the extensions list, so that when Printing runs, the
    # log entry is already there.)
    cost.name = 'my_training_objective'
    monitoring = TrainingDataMonitoring([cost], after_epoch=True)
    algorithm = GradientDescent(cost=cost, parameters=[W, b],
                                step_rule=Scale(0.01))
    main_loop = MainLoop(algorithm=algorithm,
                         data_stream=MyDataStream(dim),
                         extensions=[monitoring, checkpointer, printer])
    main_loop.run()
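
The step rule is the natural knob to turn if plain SGD stalls. Below is a hedged sketch, not from the gist, that composes gradient clipping with momentum using CompositeRule, StepClipping, and Momentum from blocks.algorithms; the argument names reflect my reading of that API, so verify against the Blocks docs. Clipping goes first in the list so the threshold applies before the momentum scaling.

from blocks.algorithms import (GradientDescent, CompositeRule,
                               Momentum, StepClipping)

# Clip the step norm at 10, then apply momentum SGD.
step_rule = CompositeRule([StepClipping(10.0),
                           Momentum(learning_rate=0.01, momentum=0.9)])
algorithm = GradientDescent(cost=cost, parameters=[W, b],
                            step_rule=step_rule)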