Created February 9, 2014 01:00
-
-
Save kespindler/8892664 to your computer and use it in GitHub Desktop.
Benchmark EC2 GPU instances with Pylearn2 deep learning on MNIST dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Provision an Ubuntu 12.04 EC2 GPU instance for the Pylearn2 MNIST
# benchmark: system packages, Theano, the CUDA 5.5 toolkit, Pylearn2, the
# MNIST data set, and a ~/.theanorc that selects GPU device gpu0.
#
# Every path below (.bashrc, .theanorc, data/, pylearn2/) is relative to
# $HOME, so move there explicitly instead of depending on the caller's cwd.
set -e  # abort at the first failing step instead of continuing half-installed
cd "$HOME"

# --- System packages and Python tooling -----------------------------------
sudo apt-get -y update
sudo apt-get -y upgrade
sudo apt-get -y dist-upgrade
sudo apt-get -y install git make python-dev python-setuptools libblas-dev gfortran g++ python-pip python-numpy python-scipy liblapack-dev
sudo pip install ipython nose
sudo apt-get -y install screen   # -y: never block on an interactive prompt

# GitHub no longer serves the unauthenticated git:// protocol; use https.
sudo pip install --upgrade git+https://github.com/Theano/Theano.git

# --- CUDA 5.5 from NVIDIA's Ubuntu 12.04 package repository ----------------
wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1204/x86_64/cuda-repo-ubuntu1204_5.5-0_amd64.deb
sudo dpkg -i cuda-repo-ubuntu1204_5.5-0_amd64.deb
sudo apt-get -y update
sudo apt-get -y install cuda     # large install; would otherwise stop and ask

# --- Pylearn2, installed from source ----------------------------------------
# Skip the clone on re-runs (git clone fails if the directory already exists).
if [ ! -d pylearn2 ]; then
    git clone https://github.com/lisa-lab/pylearn2.git
fi
(cd pylearn2 && sudo python setup.py install)

# --- Environment ------------------------------------------------------------
# Append to .bashrc only once. The quoted heredoc keeps $PATH, $LD_LIBRARY_PATH
# and $HOME literal, so they are expanded at each login instead of being
# frozen to this shell's values. ${VAR:+:$VAR} avoids writing a trailing ':'
# (an empty search-path entry) when LD_LIBRARY_PATH is unset.
if ! grep -q 'PYLEARN2_DATA_PATH' .bashrc 2>/dev/null; then
    cat >> .bashrc <<'EOF'
export PATH=/usr/local/cuda-5.5/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-5.5/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
export PYLEARN2_DATA_PATH=$HOME/data
EOF
fi
# Also export for the rest of this script; re-sourcing .bashrc may be a no-op
# in a non-interactive shell.
export PATH=/usr/local/cuda-5.5/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-5.5/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
export PYLEARN2_DATA_PATH=$HOME/data

# --- MNIST data set, unpacked into $PYLEARN2_DATA_PATH/mnist ----------------
mkdir -p data/mnist
(
    cd data/mnist
    for name in train-images-idx3-ubyte train-labels-idx1-ubyte \
                t10k-images-idx3-ubyte t10k-labels-idx1-ubyte; do
        # Skip files already unpacked so the script can be re-run
        # (gunzip refuses to overwrite an existing output file).
        if [ ! -f "$name" ]; then
            wget "http://yann.lecun.com/exdb/mnist/$name.gz"
            gunzip "$name.gz"
        fi
    done
)

# --- Theano configuration: float32 on GPU device gpu0, nvcc fastmath ---------
cat > .theanorc <<'EOF'
[global]
floatX = float32
device = gpu0
[nvcc]
fastmath = True
EOF

# Optional GPU BLAS sanity check; only meaningful now that CUDA is on PATH:
#THEANO_FLAGS=floatX=float32,device=gpu0 python /usr/local/lib/python2.7/dist-packages/theano/misc/check_blas.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
from pylearn2.train import Train | |
from pylearn2.datasets.mnist import MNIST | |
from pylearn2.models import softmax_regression, mlp | |
from pylearn2.training_algorithms import bgd | |
from pylearn2.termination_criteria import MonitorBased | |
from pylearn2.train_extensions import best_params | |
from pylearn2.utils import serial | |
from theano import function | |
from theano import tensor as T | |
import numpy as np | |
import os | |
# Benchmark model: a 784-input MLP with one 500-unit sigmoid hidden layer and
# a 10-class softmax output, trained with Pylearn2's batch gradient descent
# (conjugate directions, exhaustive line search) and early-stopped on the
# validation misclassification channel.
h0 = mlp.Sigmoid(layer_name='h0', dim=500, sparse_init=15)
ylayer = mlp.Softmax(layer_name='y', n_classes=10, irange=0)
layers = [h0, ylayer]
model = mlp.MLP(layers, nvis=784)

# Split: training rows [0, 50000) to fit, training rows [50000, 60000) to
# validate, and test rows [0, 10000) for the reported accuracy.
train = MNIST('train', one_hot=1, start=0, stop=50000)
valid = MNIST('train', one_hot=1, start=50000, stop=60000)
test = MNIST('test', one_hot=1, start=0, stop=10000)

# Checkpoint written by MonitorBasedSaveBest; it doubles as a cache so that a
# re-run skips training and goes straight to evaluation.
save_path = "train_best.pkl"

monitoring = dict(valid=valid)
termination = MonitorBased(channel_name="valid_y_misclass")
extensions = [best_params.MonitorBasedSaveBest(channel_name="valid_y_misclass",
                                               save_path=save_path)]
algorithm = bgd.BGD(batch_size=10000, line_search_mode='exhaustive', conjugate=1,
                    monitoring_dataset=monitoring,
                    termination_criterion=termination)

if os.path.exists(save_path):
    model = serial.load(save_path)
else:
    print('Running training')
    train_job = Train(train, model, algorithm, extensions=extensions,
                      save_path="train.pkl", save_freq=1)
    train_job.main_loop()
    # After main_loop, `model` holds the final-epoch parameters. Evaluate the
    # best-on-validation checkpoint instead, so the first run reports the same
    # model that a cached re-run would load.
    model = serial.load(save_path)

# Compile a Theano function mapping a batch of inputs to predicted class indices.
X = model.get_input_space().make_batch_theano()
Y = model.fprop(X)
y_sym = T.argmax(Y, axis=1)
f = function([X], y_sym)

yhat = f(test.X)
# Targets are one-hot rows, so each row's argmax is its true class index.
y = np.argmax(test.get_targets(), axis=1)
# `/` is true division here (from __future__ import division at file top).
print('accuracy %s' % ((y == yhat).sum() / y.size))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.