@oarriaga
Forked from Qwlouse/lstm_reference.ipynb
Created August 14, 2017 14:16
LSTM Reference Implementation in Python
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from __future__ import division, print_function, unicode_literals\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LSTM Reference implementation in Numpy\n",
"\n",
"This implementation is meant as a reference for understanding the LSTM and for checking other implementations.\n",
"The figures and formulas are taken from [\"LSTM: A Search Space Odyssey\"](http://arxiv.org/abs/1503.04069).\n",
"It is not optimized for speed or memory consumption in any way."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the following activation functions (all pointwise on vector inputs):\n",
"\n",
" * The logistic sigmoid for all the gates: $\\sigma(\\mathbf{x}) = \\frac{1}{1+e^{-\\mathbf{x}}}$ \n",
" * hyperbolic tangent for block input and output: $g(\\mathbf{x}) = h(\\mathbf{x}) = \\text{tanh}(\\mathbf{x}) = \\frac{e^\\mathbf{x} - e^{-\\mathbf{x}}}{e^\\mathbf{x} + e^{-\\mathbf{x}}}$\n",
" \n",
"with the corresponding derivatives:\n",
" * $\\sigma'(\\mathbf{x}) = \\sigma(\\mathbf{x}) \\odot (\\mathbf{1} - \\sigma(\\mathbf{x})) $\n",
" * $\\text{tanh}'(\\mathbf{x}) = \\mathbf{1} - \\text{tanh}(\\mathbf{x})^2$"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sigma = lambda x: 1./(1 + np.exp(-x))\n",
"sigma_deriv = lambda x: sigma(x) * (1 - sigma(x))\n",
"\n",
"g = lambda x: np.tanh(x)\n",
"g_deriv = lambda x: 1 - g(x)**2\n",
"\n",
"h = lambda x: np.tanh(x)\n",
"h_deriv = lambda x: 1 - h(x)**2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weights\n",
"Let $N$ be the number of LSTM blocks and $M$ the number of inputs. Then we get the following weights:\n",
"\n",
" * Input weights: $\\mathbf{W}_z$, $\\mathbf{W}_i$, $\\mathbf{W}_f$, $\\mathbf{W}_o$ $\\in \\mathbb{R}^{M \\times N}$\n",
" * Recurrent weights: $\\mathbf{R}_z$, $\\mathbf{R}_i$, $\\mathbf{R}_f$, $\\mathbf{R}_o$ $\\in \\mathbb{R}^{N \\times N}$\n",
" * Peephole weights: $\\mathbf{p}_i$, $\\mathbf{p}_f$, $\\mathbf{p}_o$ $\\in \\mathbb{R}^{N}$\n",
" * Bias weights: $\\mathbf{b}_z$, $\\mathbf{b}_i$, $\\mathbf{b}_f$, $\\mathbf{b}_o$ $\\in \\mathbb{R}^{N}$\n",
"\n",
"The subscript letters stand for the block input ($z$), the input gate ($i$), the forget gate ($f$), and the output gate ($o$)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"M = 2 # dimensionality of inputs\n",
"N = 3 # number of LSTM blocks\n",
"\n",
"rnd = np.random.RandomState()\n",
"\n",
"Wz = rnd.randn(M, N) * 0.1\n",
"Wi = rnd.randn(M, N) * 0.1\n",
"Wf = rnd.randn(M, N) * 0.1\n",
"Wo = rnd.randn(M, N) * 0.1\n",
"\n",
"Rz = rnd.randn(N, N) * 0.1\n",
"Ri = rnd.randn(N, N) * 0.1\n",
"Rf = rnd.randn(N, N) * 0.1\n",
"Ro = rnd.randn(N, N) * 0.1\n",
"\n",
"pi = rnd.randn(N) * 0.1\n",
"pf = rnd.randn(N) * 0.1\n",
"po = rnd.randn(N) * 0.1\n",
"\n",
"bz = rnd.randn(N) * 0.1\n",
"bi = rnd.randn(N) * 0.1\n",
"bf = rnd.randn(N) * 0.1\n",
"bo = rnd.randn(N) * 0.1\n",
"\n",
"weights = (Wz, Wi, Wf, Wo, Rz, Ri, Rf, Ro, pi, pf, po, bz, bi, bf, bo)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Forward pass\n",
"\n",
"<img src=\"http://people.idsia.ch/~greff/lstm.svg\" alt=\"LSTM\" style=\"float:right;\" />\n",
"\n",
"Formulas:\n",
"\n",
"$\\bar{\\mathbf{z}}^t = \\mathbf{x}^t \\mathbf{W}_z + \\mathbf{y}^{t-1} \\mathbf{R}_z + \\mathbf{b}_z $<br>\n",
"$\\mathbf{z}^t = g(\\bar{\\mathbf{z}}^t) $\n",
"*(block input)*\n",
"\n",
"$\\bar{\\mathbf{i}}^t = \\mathbf{x}^t \\mathbf{W}_i + \\mathbf{y}^{t-1} \\mathbf{R}_i + \\mathbf{p}_i \\odot \\mathbf{c}^{t-1} + \\mathbf{b}_i $<br>\n",
"$\\mathbf{i}^t = \\sigma(\\bar{\\mathbf{i}}^t) $\n",
"*(input gate)*\n",
" \n",
"$\\bar{\\mathbf{f}}^t = \\mathbf{x}^t \\mathbf{W}_f + \\mathbf{y}^{t-1} \\mathbf{R}_f + \\mathbf{p}_f \\odot \\mathbf{c}^{t-1} + \\mathbf{b}_f $<br>\n",
"$\\mathbf{f}^t = \\sigma(\\bar{\\mathbf{f}}^t) $\n",
"*(forget gate)*\n",
" \n",
"$\\mathbf{c}^t = \\mathbf{z}^t \\odot \\mathbf{i}^t + \\mathbf{c}^{t-1} \\odot \\mathbf{f}^t$\n",
"*(cell state)*\n",
"\n",
"$\\bar{\\mathbf{o}}^t = \\mathbf{x}^t \\mathbf{W}_o + \\mathbf{y}^{t-1} \\mathbf{R}_o + \\mathbf{p}_o \\odot \\mathbf{c}^{t} + \\mathbf{b}_o$<br>\n",
"$\\mathbf{o}^t = \\sigma(\\bar{\\mathbf{o}}^t)$\n",
"*(output gate)*\n",
"\n",
"$\\mathbf{y}^t = h(\\mathbf{c}^t) \\odot \\mathbf{o}^t$\n",
"*(block output)*\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def forward(x, weights):\n",
"    Wz, Wi, Wf, Wo, Rz, Ri, Rf, Ro, pi, pf, po, bz, bi, bf, bo = weights\n",
"    N = Rz.shape[0]  # number of hidden units (LSTM blocks)\n",
"    T, M = x.shape\n",
"    za = np.zeros([T, N])\n",
"    z = np.zeros([T, N])\n",
"    ia = np.zeros([T, N])\n",
"    i = np.zeros([T, N])\n",
"    fa = np.zeros([T, N])\n",
"    f = np.zeros([T, N])\n",
"    c = np.zeros([T, N])\n",
"    oa = np.zeros([T, N])\n",
"    o = np.zeros([T, N])\n",
"    y = np.zeros([T, N])\n",
"\n",
"    # At t=0 the t-1 index wraps around to the last timestep (index -1),\n",
"    # which has not been computed yet and is therefore still zero.\n",
"    for t in range(T):\n",
"        za[t] = np.dot(x[t], Wz) + np.dot(y[t-1], Rz) + bz\n",
"        z[t] = g(za[t])\n",
"\n",
"        ia[t] = np.dot(x[t], Wi) + np.dot(y[t-1], Ri) + pi*c[t-1] + bi\n",
"        i[t] = sigma(ia[t])\n",
"\n",
"        fa[t] = np.dot(x[t], Wf) + np.dot(y[t-1], Rf) + pf*c[t-1] + bf\n",
"        f[t] = sigma(fa[t])\n",
"\n",
"        c[t] = i[t] * z[t] + f[t] * c[t-1]\n",
"\n",
"        oa[t] = np.dot(x[t], Wo) + np.dot(y[t-1], Ro) + po*c[t] + bo\n",
"        o[t] = sigma(oa[t])\n",
"\n",
"        y[t] = o[t] * h(c[t])\n",
"\n",
"    fwd_state = (za, z, ia, i, fa, f, oa, o, c, y)\n",
"    return y, fwd_state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Backward Pass\n",
"Let $\\Delta^t$ be the vector of deltas received from the layer above. Formally, these are $\\frac{\\partial E}{\\partial \\mathbf{y}^t}$ excluding the recurrent dependencies, which we resolve in $\\mathbf{\\delta y}^t$:\n",
"\n",
"\n",
"$\\mathbf{\\delta y}^t = \\Delta^t + \\mathbf{\\delta z}^{t+1} \\mathbf{R}_z^T + \n",
" \\mathbf{\\delta i}^{t+1} \\mathbf{R}_i^T + \n",
" \\mathbf{\\delta f}^{t+1} \\mathbf{R}_f^T + \n",
" \\mathbf{\\delta o}^{t+1} \\mathbf{R}_o^T$\n",
"\n",
"$\\mathbf{\\delta o}^t = \\mathbf{\\delta y}^t \\odot h(\\mathbf{c}^t) \\odot \\sigma'(\\bar{\\mathbf{o}}^t) $\n",
"\n",
"$\\mathbf{\\delta c}^t = \\mathbf{\\delta y}^t \\odot \\mathbf{o}^t \\odot h'(\\mathbf{c}^t) + \n",
" \\mathbf{p}_o \\odot \\mathbf{\\delta o}^t +\n",
" \\mathbf{p}_i \\odot \\mathbf{\\delta i}^{t+1} +\n",
" \\mathbf{p}_f \\odot \\mathbf{\\delta f}^{t+1} +\n",
" \\mathbf{\\delta c}^{t+1} \\odot \\mathbf{f}^{t+1}$ \n",
" \n",
"$\\mathbf{\\delta f}^t = \\mathbf{\\delta c}^t \\odot \\mathbf{c}^{t-1} \\odot \\sigma'(\\bar{\\mathbf{f}}^t) $\n",
"\n",
"$\\mathbf{\\delta i}^t = \\mathbf{\\delta c}^t \\odot \\mathbf{z}^{t} \\odot \\sigma'(\\bar{\\mathbf{i}}^t) $\n",
"\n",
"$\\mathbf{\\delta z}^t = \\mathbf{\\delta c}^t \\odot \\mathbf{i}^{t} \\odot g'(\\bar{\\mathbf{z}}^t) $\n",
"\n",
"\n",
"Deltas for the inputs. Only needed if there is a layer below that needs training:<br>\n",
"$\\mathbf{\\delta x}^t = \\mathbf{\\delta z}^t \\mathbf{W}_z^T + \n",
" \\mathbf{\\delta i}^t \\mathbf{W}_i^T + \n",
" \\mathbf{\\delta f}^t \\mathbf{W}_f^T + \n",
" \\mathbf{\\delta o}^t \\mathbf{W}_o^T$\n",
"\n",
"Gradients for the weights, where $\\star$ stands for any of $z, i, f, o$ and $\\langle \\cdot, \\cdot \\rangle$ denotes the outer product of two vectors:<br>\n",
"$\\delta\\mathbf{W}_\\star = \\sum^T_{t=0} \\langle \\mathbf{x}^t, \\mathbf{\\delta\\star}^t \\rangle$\n",
"\n",
"$\\delta\\mathbf{R}_\\star = \\sum^{T-1}_{t=0} \\langle \\mathbf{y}^t, \\mathbf{\\delta\\star}^{t+1} \\rangle$\n",
"\n",
"$\\delta\\mathbf{p}_i = \\sum^{T-1}_{t=0} \\mathbf{c}^t \\odot \\mathbf{\\delta i}^{t+1}$<br>\n",
"$\\delta\\mathbf{p}_f = \\sum^{T-1}_{t=0} \\mathbf{c}^t \\odot \\mathbf{\\delta f}^{t+1}$<br>\n",
"$\\delta\\mathbf{p}_o = \\sum^{T}_{t=0} \\mathbf{c}^t \\odot \\mathbf{\\delta o}^{t}$\n",
"\n",
"$\\delta\\mathbf{b}_\\star = \\sum^{T}_{t=0} \\mathbf{\\delta\\star}^{t}$\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def backward(deltas, weights, fwd_state):\n",
"    # NOTE: the input sequence x is not an argument; it is read from the global scope\n",
"    # and must be the same x that was passed to forward().\n",
"    Wz, Wi, Wf, Wo, Rz, Ri, Rf, Ro, pi, pf, po, bz, bi, bf, bo = weights\n",
"    za, z, ia, i, fa, f, oa, o, c, y = fwd_state\n",
"    T, N = deltas.shape\n",
"    M = Wz.shape[0]\n",
"\n",
"    # we make the derivative arrays longer so t+1 is automatically zero at the ends\n",
"    dz = np.zeros([T+1, N])\n",
"    di = np.zeros([T+1, N])\n",
"    df = np.zeros([T+1, N])\n",
"    dc = np.zeros([T+1, N])\n",
"    do = np.zeros([T+1, N])\n",
"    dy = np.zeros([T, N])\n",
"    dx = np.zeros([T, M])\n",
"\n",
"    # initialize gradients\n",
"    gradients = [np.zeros_like(w) for w in weights]\n",
"    dWz, dWi, dWf, dWo, dRz, dRi, dRf, dRo, dpi, dpf, dpo, dbz, dbi, dbf, dbo = gradients\n",
"\n",
"    for t in reversed(range(T)):\n",
"        dy[t] = deltas[t] + np.dot(di[t+1], Ri.T) + np.dot(df[t+1], Rf.T) + \\\n",
"                np.dot(do[t+1], Ro.T) + np.dot(dz[t+1], Rz.T)\n",
"        do[t] = dy[t] * h(c[t]) * sigma_deriv(oa[t])\n",
"        dc[t] = dy[t] * o[t] * h_deriv(c[t]) + po * do[t]\n",
"        if t < T-1:  # recurrent terms from t+1; f[t+1] does not exist at the last timestep\n",
"            dc[t] += pi * di[t+1] + pf * df[t+1] + dc[t+1] * f[t+1]\n",
"        if t > 0:  # c[t-1] would wrap around at t=0; df[0] stays zero since c^{-1} = 0\n",
"            df[t] = dc[t] * c[t-1] * sigma_deriv(fa[t])\n",
"        di[t] = dc[t] * z[t] * sigma_deriv(ia[t])\n",
"        dz[t] = dc[t] * i[t] * g_deriv(za[t])\n",
"\n",
"        # Input Deltas\n",
"        dx[t] = np.dot(dz[t], Wz.T) + np.dot(di[t], Wi.T) + np.dot(df[t], Wf.T) + np.dot(do[t], Wo.T)\n",
"\n",
"        # Gradients for the weights\n",
"        dWz += np.outer(x[t], dz[t])\n",
"        dWi += np.outer(x[t], di[t])\n",
"        dWf += np.outer(x[t], df[t])\n",
"        dWo += np.outer(x[t], do[t])\n",
"        dRz += np.outer(y[t], dz[t+1])\n",
"        dRi += np.outer(y[t], di[t+1])\n",
"        dRf += np.outer(y[t], df[t+1])\n",
"        dRo += np.outer(y[t], do[t+1])\n",
"        dpi += c[t] * di[t+1]\n",
"        dpf += c[t] * df[t+1]\n",
"        dpo += c[t] * do[t]\n",
"        dbz += dz[t]\n",
"        dbi += di[t]\n",
"        dbf += df[t]\n",
"        dbo += do[t]\n",
"\n",
"    bwd_state = [dz, di, df, dc, do, dy]\n",
"    return dx, gradients, bwd_state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Finite Difference Checking\n",
"Let's check the gradients against a numerical approximation to make sure we didn't make any mistakes."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# We use the squared error as the loss function\n",
"SE = lambda y, t: 0.5 * np.sum((y-t)**2)\n",
"SE_deriv = lambda y, t: y - t"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def finite_diff(f, initial_x, eps=1e-10):\n",
"    # forward-difference approximation of the gradient of the scalar function f at initial_x\n",
"    err1 = f(initial_x)\n",
"    delta = np.zeros_like(initial_x)\n",
"    diff = np.zeros_like(initial_x)\n",
"    for i in range(delta.size):\n",
"        delta.flat[i] = eps\n",
"        err2 = f(initial_x + delta)\n",
"        diff.flat[i] = (err2 - err1) / eps\n",
"        delta.flat[i] = 0\n",
"    return diff"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"T = 10 # nr_timesteps\n",
"x = rnd.randn(T, M)\n",
"targets = rnd.randn(T, N)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y, fwd_state = forward(x, weights)\n",
"deltas = SE_deriv(y, targets)\n",
"dx, gradients, bwd_state = backward(deltas, weights, fwd_state)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def func1(inputs):\n",
"    return SE(forward(inputs, weights)[0], targets)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dx_approx = finite_diff(func1, x)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"4.4914192362054091e-09"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"SE(dx, dx_approx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make checking the weight gradients easier, we place all the weights in one large array. \n",
"Every array in the weights list is then a reshaped slice of this larger array."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"weight_sizes = [w.size for w in weights]\n",
"total_nr_weights = sum(weight_sizes)\n",
"split_idx = np.hstack(([0], np.cumsum(weight_sizes)))\n",
"shapes = [w.shape for w in weights]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def split_weights(all_weights):\n",
"    return [all_weights[i:j].reshape(s) for i, j, s in zip(split_idx[:-1], split_idx[1:], shapes)]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def func2(allweights):\n",
"    weights = split_weights(allweights)\n",
"    return SE(forward(x, weights)[0], targets)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"all_weights = rnd.randn(total_nr_weights)\n",
"weights = split_weights(all_weights)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"grad_approx = finite_diff(func2, all_weights)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y, fwd_state = forward(x, weights)\n",
"deltas = y - targets\n",
"dx, gradients, bwd_state = backward(deltas, weights, fwd_state)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wz = 1.6018560246e-09\n",
"Wi = 8.05556588901e-10\n",
"Wf = 1.98769681238e-09\n",
"Wo = 1.18538571315e-09\n",
"Rz = 4.39003824181e-09\n",
"Ri = 2.24740426348e-09\n",
"Rf = 2.8298336655e-09\n",
"Ro = 2.09207265237e-09\n",
"pi = 6.86448309643e-10\n",
"pf = 1.06052334981e-10\n",
"po = 5.34335201749e-10\n",
"bz = 4.52704935075e-10\n",
"bi = 3.14392932658e-10\n",
"bf = 1.20972654246e-10\n",
"bo = 1.33090539717e-10\n"
]
}
],
"source": [
"weight_names = ['Wz', 'Wi', 'Wf', 'Wo', 'Rz', 'Ri', 'Rf', 'Ro', 'pi', 'pf', 'po', 'bz', 'bi', 'bf', 'bo']\n",
"\n",
"for g_calc, g_approx, name in zip(gradients, split_weights(grad_approx), weight_names):\n",
"    print(name, '=', SE(g_calc, g_approx))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
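The derivative identities listed with the activation functions can be spot-checked numerically. A minimal sketch that compares the closed-form derivatives against a central-difference estimate on a grid of points; the helper `central_diff`, the grid, and the step size `eps=1e-5` are illustrative choices, not part of the original notebook.

```python
import numpy as np

# Activation functions and derivatives, as defined in the notebook
sigma = lambda x: 1. / (1 + np.exp(-x))
sigma_deriv = lambda x: sigma(x) * (1 - sigma(x))
tanh_deriv = lambda x: 1 - np.tanh(x) ** 2


def central_diff(fn, x, eps=1e-5):
    """Pointwise central-difference estimate of fn'(x) (hypothetical helper)."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


xs = np.linspace(-5, 5, 101)
err_sigma = np.max(np.abs(sigma_deriv(xs) - central_diff(sigma, xs)))
err_tanh = np.max(np.abs(tanh_deriv(xs) - central_diff(np.tanh, xs)))
print("max error sigma':", err_sigma)
print("max error tanh': ", err_tanh)
```

Both errors should be many orders of magnitude below 1, confirming $\sigma' = \sigma \odot (1 - \sigma)$ and $\text{tanh}' = 1 - \text{tanh}^2$.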
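In the Weights section, the four input matrices $\mathbf{W}_z, \mathbf{W}_i, \mathbf{W}_f, \mathbf{W}_o$ all share the shape $M \times N$, and the recurrent matrices all share $N \times N$. Implementations optimized for speed often stack them column-wise so that one matrix product yields all four pre-activations. A minimal sketch of that equivalence, assuming the gate order $z, i, f, o$ (an arbitrary choice here; libraries differ in the order they use):

```python
import numpy as np

rnd = np.random.RandomState(0)
M, N = 2, 3  # input dimensionality and number of LSTM blocks, as in the notebook

# Separate per-gate input weights, shape (M, N) each
Wz, Wi, Wf, Wo = (rnd.randn(M, N) * 0.1 for _ in range(4))

# Stack them column-wise into one (M, 4N) matrix; the gate order z, i, f, o is an assumption
W_stacked = np.hstack([Wz, Wi, Wf, Wo])

x_t = rnd.randn(M)
# One matrix product, then split the result back into the four pre-activations
za, ia, fa, oa = np.split(np.dot(x_t, W_stacked), 4)

for stacked_part, W in zip((za, ia, fa, oa), (Wz, Wi, Wf, Wo)):
    assert np.allclose(stacked_part, np.dot(x_t, W))
print("stacked shape:", W_stacked.shape)
```

The same stacking works for the recurrent weights and biases; the peephole terms are elementwise products and would stay separate.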
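The forward-pass formulas describe a single time step, which the notebook's `forward` function unrolls over all $T$ steps. A minimal sketch of one step as a standalone function; the name `lstm_step` and its argument order are hypothetical, while the weight tuple follows the notebook's `(Wz, Wi, Wf, Wo, Rz, Ri, Rf, Ro, pi, pf, po, bz, bi, bf, bo)` layout.

```python
import numpy as np

sigma = lambda x: 1. / (1 + np.exp(-x))


def lstm_step(x_t, y_prev, c_prev, weights):
    """One LSTM time step with peephole connections (hypothetical helper)."""
    Wz, Wi, Wf, Wo, Rz, Ri, Rf, Ro, pi, pf, po, bz, bi, bf, bo = weights
    z = np.tanh(np.dot(x_t, Wz) + np.dot(y_prev, Rz) + bz)               # block input
    i = sigma(np.dot(x_t, Wi) + np.dot(y_prev, Ri) + pi * c_prev + bi)  # input gate
    f = sigma(np.dot(x_t, Wf) + np.dot(y_prev, Rf) + pf * c_prev + bf)  # forget gate
    c = z * i + c_prev * f                                                # cell state
    o = sigma(np.dot(x_t, Wo) + np.dot(y_prev, Ro) + po * c + bo)       # output gate peeks at the new c
    y = np.tanh(c) * o                                                    # block output
    return y, c


M, N = 2, 3
# With all weights zero, every gate is sigma(0) = 0.5 and the block input is tanh(0) = 0,
# so c = 0.5 * c_prev and y = 0.5 * tanh(c).
zero_weights = tuple(np.zeros((M, N)) for _ in range(4)) + \
               tuple(np.zeros((N, N)) for _ in range(4)) + \
               tuple(np.zeros(N) for _ in range(7))
y, c = lstm_step(np.ones(M), np.zeros(N), np.ones(N), zero_weights)
print(c)  # [0.5 0.5 0.5]
print(y)
```

Calling `lstm_step` for $t = 0, \dots, T-1$ with zero initial `y_prev` and `c_prev` should reproduce the rows of `forward`'s output.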
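For a single time step with zero initial state, all the $t+1$ and $t-1$ terms in the backward-pass formulas vanish, which makes the remaining deltas easy to check in isolation. A minimal sketch under those assumptions: it computes $\delta\mathbf{b}_z, \delta\mathbf{b}_i, \delta\mathbf{b}_o$ from the $\mathbf{\delta o}^t$, $\mathbf{\delta c}^t$, $\mathbf{\delta i}^t$ and $\mathbf{\delta z}^t$ formulas and compares them with central differences. The functions `forward_step`, `loss` and `bias_grads` and the random setup are illustrative, not taken from the notebook.

```python
import numpy as np

sigma = lambda x: 1. / (1 + np.exp(-x))
sigma_deriv = lambda x: sigma(x) * (1 - sigma(x))
tanh_deriv = lambda x: 1 - np.tanh(x) ** 2

rnd = np.random.RandomState(1)
M, N = 2, 3
x = rnd.randn(M)
target = rnd.randn(N)
Wz, Wi, Wo = (rnd.randn(M, N) for _ in range(3))
po = rnd.randn(N)  # output-gate peephole


def forward_step(bz, bi, bo):
    """Single step with y^{t-1} = c^{t-1} = 0: R, p_i, p_f and the forget gate drop out."""
    za = np.dot(x, Wz) + bz
    ia = np.dot(x, Wi) + bi
    z, i = np.tanh(za), sigma(ia)
    c = z * i
    oa = np.dot(x, Wo) + po * c + bo
    o = sigma(oa)
    y = np.tanh(c) * o
    return za, z, ia, i, c, oa, o, y


def loss(bz, bi, bo):
    y = forward_step(bz, bi, bo)[-1]
    return 0.5 * np.sum((y - target) ** 2)  # squared error, as in the notebook


def bias_grads(bz, bi, bo):
    za, z, ia, i, c, oa, o, y = forward_step(bz, bi, bo)
    dy = y - target                         # Delta^t (no recurrent terms)
    do = dy * np.tanh(c) * sigma_deriv(oa)  # delta o^t
    dc = dy * o * tanh_deriv(c) + po * do   # delta c^t (no t+1 terms)
    di = dc * z * sigma_deriv(ia)           # delta i^t
    dz = dc * i * tanh_deriv(za)            # delta z^t
    return dz, di, do  # for a single step, delta b_star equals delta star^t


biases = [rnd.randn(N) * 0.1 for _ in range(3)]
analytic = bias_grads(*biases)

eps = 1e-6
max_err = 0.0
for k in range(3):      # which bias vector: b_z, b_i, b_o
    for j in range(N):  # which element of that vector
        plus = [b.copy() for b in biases]
        minus = [b.copy() for b in biases]
        plus[k][j] += eps
        minus[k][j] -= eps
        numeric = (loss(*plus) - loss(*minus)) / (2 * eps)
        max_err = max(max_err, abs(numeric - analytic[k][j]))
print("max abs error:", max_err)
```

Note how the output-gate peephole adds the $\mathbf{p}_o \odot \mathbf{\delta o}^t$ term to $\mathbf{\delta c}^t$, because $\mathbf{c}^t$ feeds into $\bar{\mathbf{o}}^t$.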
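The notebook's `finite_diff` uses a one-sided (forward) difference with `eps=1e-10`. A central difference, $(f(x+\epsilon) - f(x-\epsilon)) / 2\epsilon$, has truncation error $O(\epsilon^2)$ instead of $O(\epsilon)$, which typically allows a larger step such as `eps=1e-5` and suffers less from floating-point cancellation. A minimal sketch of that variant (the names `central_finite_diff` and `relative_error` are hypothetical), checked against a least-squares loss whose gradient $A^T(Aw - b)$ is known in closed form; the relative-error measure is a common convention, not something taken from the notebook.

```python
import numpy as np


def central_finite_diff(f, initial_x, eps=1e-5):
    """Central-difference gradient of a scalar function f at initial_x (hypothetical helper)."""
    delta = np.zeros_like(initial_x)
    diff = np.zeros_like(initial_x)
    for i in range(delta.size):
        delta.flat[i] = eps
        diff.flat[i] = (f(initial_x + delta) - f(initial_x - delta)) / (2 * eps)
        delta.flat[i] = 0
    return diff


def relative_error(a, b):
    """Scale-independent comparison of two gradient arrays."""
    return np.linalg.norm(a - b) / max(np.linalg.norm(a) + np.linalg.norm(b), 1e-12)


rnd = np.random.RandomState(0)
A = rnd.randn(5, 4)
b = rnd.randn(5)
w = rnd.randn(4)

loss = lambda w: 0.5 * np.sum((np.dot(A, w) - b) ** 2)
grad_exact = np.dot(A.T, np.dot(A, w) - b)
grad_approx = central_finite_diff(loss, w)
print("relative error:", relative_error(grad_exact, grad_approx))
```

The trade-off is two function evaluations per parameter instead of one, which is usually acceptable for a reference check.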
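The notebook's `split_weights` applies basic slicing followed by `reshape` to a contiguous one-dimensional array, so NumPy returns views rather than copies, and each weight array shares memory with the flat vector. A small sketch illustrating this with made-up shapes; the notebook itself only reads the weights, so this matters mainly if the flat vector is updated in place (for example by a hypothetical optimizer step).

```python
import numpy as np

shapes = [(2, 3), (3, 3), (3,)]  # illustrative shapes, not the notebook's full list
sizes = [int(np.prod(s)) for s in shapes]
split_idx = np.hstack(([0], np.cumsum(sizes)))
all_weights = np.arange(split_idx[-1], dtype=float)


def split_weights(all_weights):
    # same slicing scheme as the notebook's split_weights
    return [all_weights[i:j].reshape(s) for i, j, s in zip(split_idx[:-1], split_idx[1:], shapes)]


parts = split_weights(all_weights)
# Every part is a view into all_weights, not a copy
assert all(np.shares_memory(p, all_weights) for p in parts)

# An in-place update of the flat vector is visible through the views
all_weights -= 1.0
print(parts[0])
```

An out-of-place update such as `all_weights = all_weights - 1.0` would instead create a new array and leave the views pointing at the old one.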