@harperjiang
Last active June 21, 2019 08:28
Performance comparison of numpy vs nd4j on LSTM implementation

My task is to find a new LSTM architecture that improves performance on text prediction: given the preceding character sequence, predict the next character. For this purpose, I need to hand-write a new LSTM implementation. I already have a framework in numpy, which I would like to migrate to Nd4j.

I have the following hyperparameters:

  • numChar: Number of distinct characters
  • hiddenDim: The size of hidden dimension in my LSTM cell
  • batchSize: batch size

I have the following parameters:

  • c2v: a character embedding matrix, shape [numChar, hiddenDim]
  • v2c: a mapping matrix from hidden dimension to char, shape [hiddenDim, numChar]
  • w1-w4: network weights, shape [2*hiddenDim, hiddenDim]
  • b1-b4: bias, shape [hiddenDim]
  • h0: Init hidden state, shape [batchSize, hiddenDim]
  • c0: Init cell state, shape [batchSize, hiddenDim]

My simple LSTM cell follows the following steps:

  • input is a vector of shape [batchSize, 1]; each element is an index from 0 to numChar - 1
  • fetch the embedding from c2v using the index, get a matrix embedded of shape [batchSize, hiddenDim]
  • concatenate embedded with h0, get a matrix concat of shape [batchSize, 2*hiddenDim]
  • create a forget_gate = sigmoid(concat * w1 + b1)
  • create an info_gate = tanh(concat * w2 + b2) * sigmoid(concat * w3 + b3)
  • update cell state: c_i+1 = c_i * forget_gate + info_gate
  • update hidden state: h_i+1 = tanh(c_i+1) * sigmoid(concat * w4 + b4)
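
The steps above can be sketched as a single NumPy function. This is only a minimal sketch: it assumes that `*` between `concat` and a weight matrix means matrix multiplication, while the products between gates are element-wise. The function name `lstm_step`, the `params` dictionary, and the small sizes are hypothetical and are not part of the attached code.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(idx, h, c, params):
    """One step of the cell described above (hypothetical helper).

    idx:  int array of shape [batch_size], character indices
    h, c: arrays of shape [batch_size, hidden_dim]
    """
    embedded = params["c2v"][idx]                     # [batch, hidden]
    concat = np.concatenate((embedded, h), axis=1)    # [batch, 2 * hidden]
    forget_gate = sigmoid(concat @ params["w1"] + params["b1"])
    info_gate = (np.tanh(concat @ params["w2"] + params["b2"])
                 * sigmoid(concat @ params["w3"] + params["b3"]))
    c_next = c * forget_gate + info_gate
    h_next = np.tanh(c_next) * sigmoid(concat @ params["w4"] + params["b4"])
    return h_next, c_next


# Small, made-up sizes for demonstration only
rng = np.random.RandomState(0)
num_char, hidden_dim, batch_size = 10, 4, 3
params = {"c2v": rng.uniform(-0.1, 0.1, (num_char, hidden_dim))}
for i in range(1, 5):
    params["w%d" % i] = rng.uniform(-0.1, 0.1, (2 * hidden_dim, hidden_dim))
    params["b%d" % i] = np.zeros(hidden_dim)

h0 = np.zeros((batch_size, hidden_dim))
c0 = np.zeros((batch_size, hidden_dim))
idx = np.array([1, 5, 9])
h1, c1 = lstm_step(idx, h0, c0, params)
```

Here the index vector is one-dimensional rather than [batchSize, 1], which keeps the row lookup into `c2v` simple.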

The Nd4j implementation is much slower than numpy. The table below shows the cumulative time, in seconds, after each of the first 4 steps of the computation in the attached source code. Even setting aside the anomalous concat operation, the other nd4j operations are all at least 5-6 times slower than their numpy counterparts.

         Embed    Concat   F Gate   I Gate
Numpy    0.01     0.018    0.332    0.46
Nd4j     0.409    6.056    6.846    7.449
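
Because the figures are cumulative, the time spent in an individual step is the difference between adjacent columns. The following is a small sketch of that arithmetic, assuming each column is the total time up to and including that step; the helper name `per_step` is made up for illustration.

```python
steps = ["Embed", "Concat", "F Gate", "I Gate"]
cumulative = {
    "Numpy": [0.01, 0.018, 0.332, 0.46],
    "Nd4j": [0.409, 6.056, 6.846, 7.449],
}


def per_step(times):
    """Convert cumulative times into the time spent in each step."""
    previous = [0.0] + times[:-1]
    return [round(t - p, 3) for p, t in zip(previous, times)]


numpy_steps = per_step(cumulative["Numpy"])
nd4j_steps = per_step(cumulative["Nd4j"])

# Print the per-step time for each library and the Nd4j/Numpy ratio
for name, n, j in zip(steps, numpy_steps, nd4j_steps):
    print("%-7s numpy=%.3fs  nd4j=%.3fs  ratio=%.1fx" % (name, n, j, j / n))
```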

I am running numpy 1.11.2 compiled with Intel MKL and OpenBLAS on Python 3.5.2, Ubuntu 16.10. The Nd4j version is 0.7.2 with JDK 1.8.0_111.

// Nd4j benchmark (Scala)
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp
import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution
import org.nd4j.linalg.factory.Nd4j

import scala.util.Random

object Xavier {
  // Uniform initialization with bound sqrt(3 / n), n = product of all but the last dimension
  def init(shape: Array[Int]): INDArray = {
    val n = shape.dropRight(1).product
    val sd = Math.sqrt(3d / n)
    new UniformDistribution(-sd, sd).sample(shape)
  }
}

object LSTM extends App {
  val hiddenDim = 200
  val numChar = 100

  val c2v = Xavier.init(Array(numChar, hiddenDim))
  val v2c = Xavier.init(Array(hiddenDim, numChar))

  val inputSize = 2 * hiddenDim
  val w1 = Xavier.init(Array(inputSize, hiddenDim))
  val b1 = Nd4j.zeros(1, hiddenDim)
  val w2 = Xavier.init(Array(inputSize, hiddenDim))
  val b2 = Nd4j.zeros(1, hiddenDim)
  val w3 = Xavier.init(Array(inputSize, hiddenDim))
  val b3 = Nd4j.zeros(1, hiddenDim)
  val w4 = Xavier.init(Array(inputSize, hiddenDim))
  val b4 = Nd4j.zeros(1, hiddenDim)

  // Random batch: `length` time steps, each with batchSize character indices
  val length = 500
  val batchSize = 50
  val batch = (0 until length).map { _ =>
    Seq.fill(batchSize)(Random.nextInt(numChar)).toArray
  }

  val h0 = Nd4j.zeros(batchSize, hiddenDim)
  val c0 = Nd4j.zeros(batchSize, hiddenDim)

  val start = System.currentTimeMillis()
  // Reusable one-hot matrix used for the embedding lookup
  val fwdmap = Nd4j.zeros(batchSize, numChar)
  batch.foreach { item =>
    fwdmap.assign(0)
    item.zipWithIndex.foreach(p => fwdmap.putScalar(Array(p._2, p._1), 1))
    val embedded = fwdmap.mmul(c2v)
    //val embedded = c2v.get(new SpecifiedIndex(item:_*),NDArrayIndex.all())
    val concat = Nd4j.concat(1, embedded, h0)
    val fgate = sigmoid(add(concat.mmul(w1), b1))
    // Only the sigmoid on w2 is timed here; the tanh factor from the description is not computed
    val igate = sigmoid(add(concat.mmul(w2), b2))
  }
  println((System.currentTimeMillis() - start).toDouble / 1000)

  // Adds the row vector b to every row of a; the result is written into a
  def add(a: INDArray, b: INDArray) =
    Nd4j.getExecutioner.execAndReturn(new BroadcastAddOp(a, b, a, 1))

  def sigmoid(input: INDArray) =
    Nd4j.getExecutioner().execAndReturn(
      new org.nd4j.linalg.api.ops.impl.transforms.Sigmoid(input))
}

# NumPy benchmark (Python)
import numpy as np
from time import time

hidden_dim = 200
batch_size = 50
num_char = 100

# Note: the timer starts before parameter initialization
stime = time()


def xavier(shape):
    # Uniform initialization with bound sqrt(3 / n), n = product of all but the last dimension
    sq = np.sqrt(3.0 / np.prod(shape[:-1]))
    return np.random.uniform(-sq, sq, shape)


C2V = xavier([num_char, hidden_dim])
w1 = xavier([2 * hidden_dim, hidden_dim])
b1 = np.zeros([hidden_dim])
w2 = xavier([2 * hidden_dim, hidden_dim])
b2 = np.zeros([hidden_dim])
w3 = xavier([2 * hidden_dim, hidden_dim])
b3 = np.zeros([hidden_dim])
w4 = xavier([2 * hidden_dim, hidden_dim])
b4 = np.zeros([hidden_dim])
V2C = xavier([hidden_dim, num_char])


def sigmoid(input):
    return 1. / (1. + np.exp(-input))


def tanh(input):
    x_exp = np.exp(input)
    x_neg_exp = np.exp(-input)
    return (x_exp - x_neg_exp) / (x_exp + x_neg_exp)


# Generate random batch
length = 500
batch = np.random.randint(num_char, size=(batch_size, length))

h0 = np.zeros([batch_size, hidden_dim])
c0 = np.zeros([batch_size, hidden_dim])

for i in range(batch.shape[1]):
    item = batch[:, i]
    embed = C2V[np.int32(item), :]
    concat = np.concatenate((embed, h0), axis=1)
    fgate = sigmoid(np.matmul(concat, w1) + b1)
    # Only the sigmoid on w2 is timed here; the tanh factor from the description is not computed
    igate = sigmoid(np.matmul(concat, w2) + b2)

t = time() - stime
print(t)
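
The two listings obtain the embedding differently: the Nd4j version fills a one-hot matrix and multiplies it by `c2v`, while the NumPy version indexes rows of `C2V` directly. The sketch below, with small made-up sizes, checks that both forms give the same result. It makes no claim about how much of the Embed gap this difference explains.

```python
import numpy as np

num_char, hidden_dim, batch_size = 6, 4, 3    # illustrative sizes only
rng = np.random.RandomState(0)
c2v = rng.uniform(-1.0, 1.0, (num_char, hidden_dim))
item = np.array([2, 0, 5])                    # one character index per batch row

# Row lookup, as in the NumPy listing
embed_lookup = c2v[item, :]

# One-hot matrix times c2v, as in the Nd4j listing
one_hot = np.zeros((batch_size, num_char))
one_hot[np.arange(batch_size), item] = 1.0
embed_matmul = one_hot.dot(c2v)

same = np.allclose(embed_lookup, embed_matmul)
```

The one-hot form costs a [batchSize, numChar] by [numChar, hiddenDim] matrix product per step, whereas the lookup only copies batchSize rows.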