Skip to content

Instantly share code, notes, and snippets.

@ozancaglayan
Last active December 13, 2015 18:06
Show Gist options
  • Save ozancaglayan/ebe2fa5ba420622ed5e1 to your computer and use it in GitHub Desktop.
Save ozancaglayan/ebe2fa5ba420622ed5e1 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import sys
import numpy as np
import time
import theano
import theano.tensor as T
# Fixed seed so every run benchmarks the same random data.
np.random.seed(1)
# Number of samples (first axis of preds/labels) comes from the command line.
if len(sys.argv) < 2:
    sys.exit("usage: %s N_SAMPLES" % sys.argv[0])
try:
    n_samples = int(sys.argv[1])
except ValueError:
    sys.exit("N_SAMPLES must be an integer, got %r" % sys.argv[1])
if n_samples < 1:
    sys.exit("N_SAMPLES must be positive, got %d" % n_samples)
seq_steps = 32       # sequence length (second axis of preds/labels)
vocab_size = 20000   # number of classes per position (last axis of preds)
def numpy_loss(preds_, labels_):
    """Return the mean negative log-likelihood using vectorised NumPy.

    Each label is converted into an index into the flattened ``preds_`` so
    that one fancy-indexing operation gathers every target entry at once.

    :param preds_: 3-D array indexed as ``preds_[i, j, k]``; ``preds_.shape[2]``
        is the size of the class axis (presumably the vocabulary -- TODO
        confirm with callers).
    :param labels_: integer array whose C-order flattening lines up with the
        first two axes of ``preds_`` (presumably shape (batch, seq_len)).
    :returns: ``-mean(log(preds_[i, j, labels_[i, j]]))``.
    """
    n_classes = preds_.shape[2]
    # Row r of the (batch*seq, n_classes) flattened layout starts at r * n_classes.
    true_idxs = np.arange(labels_.size) * n_classes + labels_.ravel()
    # reshape(-1) is a view for contiguous input; flatten() would always copy
    # the whole (potentially very large) prediction tensor just to read
    # labels_.size elements from it.
    return -np.log(preds_.reshape(-1)[true_idxs]).mean()
def theano_loss(preds_, labels_):
    """Build the symbolic mean negative log-likelihood (Theano counterpart of numpy_loss).

    Gathers, from the flattened ``preds_``, the entry selected by each label
    and returns the negated mean of their logs.

    :param preds_: symbolic 3-D tensor; ``preds_.shape[2]`` is the size of
        the class axis.
    :param labels_: symbolic integer matrix whose flattened order lines up
        with the first two axes of ``preds_``.
    :returns: symbolic scalar ``-mean(log(preds_[i, j, labels_[i, j]]))``.
    """
    n_classes = preds_.shape[2]
    # Flat indices reach labels_.size * n_classes, which exceeds 2**31 for
    # large batches (e.g. 3356 x 32 x 20000); an int32 cast would wrap them
    # to negative values and silently gather the wrong elements.
    true_idxs = T.cast(T.arange(labels_.size) * n_classes + labels_.flatten(),
                       dtype="int64")
    return -T.log(preds_.flatten()[true_idxs]).mean()
def naive_loss(preds_, labels_):
    """Return the mean negative log-likelihood using explicit Python loops.

    Reference implementation, used as the baseline the vectorised versions
    are compared against.

    :param preds_: array indexed as ``preds_[i, j, k]`` -- presumably
        (batch, seq_len, vocab) per-token probabilities; TODO confirm.
    :param labels_: 2-D integer array; ``labels_[i, j]`` selects the entry of
        ``preds_[i, j]`` whose log is taken.
    :returns: ``-sum(log(preds_[i, j, labels_[i, j]])) / (batch * seq_len)``
        as a Python float, or ``nan`` when ``labels_`` is empty (matching the
        ``.mean()`` of an empty array in numpy_loss).
    """
    batch_size = labels_.shape[0]
    seq_len = labels_.shape[1]
    n_terms = batch_size * seq_len
    if n_terms == 0:
        return float("nan")
    loss = 0.0
    for i in range(batch_size):
        for j in range(seq_len):
            # Index the arguments, not module-level globals, so the function
            # works on whatever arrays it is given. float() keeps the running
            # sum in double precision instead of float32.
            loss += float(np.log(preds_[i, j, labels_[i, j]]))
    return -loss / n_terms
# Predictions: uniform random values in [0, 1), one per (sample, step, class).
t = time.time()
preds = np.random.uniform(size=(n_samples, seq_steps, vocab_size)).astype(np.float32)
print("preds array took %.5f seconds to create" % (time.time() - t))

# Labels: one random class index in [0, vocab_size) per (sample, step).
t = time.time()
labels = np.random.randint(0, vocab_size,
                           size=(n_samples, seq_steps)).astype(np.int32)
print("labels array took %.5f seconds to create" % (time.time() - t))

# Pure-Python loop baseline.
t = time.time()
loss_1 = naive_loss(preds, labels)
print("naive_loss took %.5f seconds to compute: %5.5f" % ((time.time() - t), loss_1))

# Vectorised NumPy.
t = time.time()
loss_3 = numpy_loss(preds, labels)
print("numpy_loss took %.5f seconds to compute: %5.5f" % ((time.time() - t), loss_3))

# Theano. Declare the input with the dtype of the generated data rather than
# floatX: when floatX is float64, a float32 ``preds`` would otherwise have to
# be converted on each call (NOTE(review): confirm exact Theano behaviour), so
# the timing would not measure the same float32 computation as above.
preds_ = T.tensor3("preds", dtype=str(preds.dtype))
labels_ = T.matrix("labels", dtype="int32")
# Compile before starting the timer so only evaluation is measured.
loss_2_fun = theano.function(inputs=[preds_, labels_],
                             outputs=[theano_loss(preds_, labels_)])
t = time.time()
loss_2 = loss_2_fun(preds, labels)
print("theano_loss took %.5f seconds to compute: %5.5f" % ((time.time() - t), loss_2[0]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment