Last active
May 10, 2020 00:38
-
-
Save damienpontifex/74561a9e6bf43b59b813e5487257aa91 to your computer and use it in GitHub Desktop.
Getting my head around TensorFlow RNN inputs, outputs and the appropriate shapes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# `dynamic_rnn`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Start from a fresh default graph so op names such as 'rnn/...' restart for this example.\n", | |
"tf.reset_default_graph()\n", | |
"\n", | |
"# Toy input with shape (batch_size=2, sequence_length=3, num_features=1).\n", | |
"values = tf.constant(\n", | |
"    np.array([[[1], [2], [3]],\n", | |
"              [[2], [3], [4]]]),\n", | |
"    dtype=tf.float32,\n", | |
")\n", | |
"\n", | |
"# A single LSTM layer with 100 units.\n", | |
"lstm_cell = tf.contrib.rnn.LSTMCell(100)\n", | |
"\n", | |
"# outputs: one 100-wide vector per time step -> shape (2, 3, 100).\n", | |
"# state: exposes .c and .h, each of shape (2, 100).\n", | |
"outputs, state = tf.nn.dynamic_rnn(\n", | |
"    cell=lstm_cell,\n", | |
"    inputs=values,\n", | |
"    dtype=tf.float32,\n", | |
")\n", | |
"\n", | |
"# Fetch both tensors in a single run so they come from the same forward pass.\n", | |
"with tf.Session() as sess:\n", | |
"    sess.run(tf.global_variables_initializer())\n", | |
"    output_run, state_run = sess.run((outputs, state))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.all(output_run[:,-1] == state_run.h)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'rnn/transpose:0' shape=(2, 3, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"outputs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'rnn/while/Exit_2:0' shape=(2, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"state.c" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'rnn/while/Exit_3:0' shape=(2, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"state.h" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# `bidirectional_dynamic_rnn`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Start from a fresh default graph for the bidirectional example.\n", | |
"tf.reset_default_graph()\n", | |
"\n", | |
"# Toy input with shape (batch_size=2, sequence_length=3, num_features=1).\n", | |
"values = tf.constant(\n", | |
"    np.array([[[1], [2], [3]],\n", | |
"              [[2], [3], [4]]]),\n", | |
"    dtype=tf.float32,\n", | |
")\n", | |
"\n", | |
"# The backward cell deliberately uses 105 units (instead of 100) so the two\n", | |
"# directions can be told apart by their shapes in the outputs below.\n", | |
"lstm_cell_fw = tf.contrib.rnn.LSTMCell(100)\n", | |
"lstm_cell_bw = tf.contrib.rnn.LSTMCell(105)\n", | |
"\n", | |
"# Returns a pair of per-direction outputs and a pair of per-direction final states:\n", | |
"# output_fw is (2, 3, 100) and output_bw is (2, 3, 105).\n", | |
"(output_fw, output_bw), (output_state_fw, output_state_bw) = tf.nn.bidirectional_dynamic_rnn(\n", | |
"    inputs=values,\n", | |
"    cell_fw=lstm_cell_fw,\n", | |
"    cell_bw=lstm_cell_bw,\n", | |
"    dtype=tf.float32,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_rnn/fw/fw/transpose:0' shape=(2, 3, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"output_fw" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'ReverseV2:0' shape=(2, 3, 105) dtype=float32>" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"output_bw" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_concat_outputs:0' shape=(2, 3, 205) dtype=float32>" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"outputs = tf.concat((output_fw, output_bw), axis=2, name='bidirectional_concat_outputs')\n", | |
"outputs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_2:0' shape=(2, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"output_state_fw.c" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(2, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"output_state_fw.h" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_2:0' shape=(2, 105) dtype=float32>" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"output_state_bw.c" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(2, 105) dtype=float32>" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"output_state_bw.h" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_concat_memory_cell:0' shape=(2, 205) dtype=float32>" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tf.concat((output_state_fw.c, output_state_bw.c), axis=1, name='bidirectional_concat_memory_cell')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'bidirectional_concat_hidden_state:0' shape=(2, 205) dtype=float32>" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tf.concat((output_state_fw.h, output_state_bw.h), axis=1, name='bidirectional_concat_hidden_state')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# GRU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Start from a fresh default graph for the GRU example.\n", | |
"tf.reset_default_graph()\n", | |
"\n", | |
"# Toy input with shape (batch_size=2, sequence_length=3, num_features=1).\n", | |
"values = tf.constant(\n", | |
"    np.array([[[1], [2], [3]],\n", | |
"              [[2], [3], [4]]]),\n", | |
"    dtype=tf.float32,\n", | |
")\n", | |
"\n", | |
"# A single GRU layer with 100 units.\n", | |
"gru_cell = tf.contrib.rnn.GRUCell(100)\n", | |
"\n", | |
"# outputs: shape (2, 3, 100).\n", | |
"# state: a single (2, 100) tensor here, with no separate .c / .h parts.\n", | |
"outputs, state = tf.nn.dynamic_rnn(\n", | |
"    cell=gru_cell,\n", | |
"    inputs=values,\n", | |
"    dtype=tf.float32,\n", | |
")\n", | |
"\n", | |
"# Fetch both tensors in a single run so they come from the same forward pass.\n", | |
"with tf.Session() as sess:\n", | |
"    sess.run(tf.global_variables_initializer())\n", | |
"    output_run, state_run = sess.run((outputs, state))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.all(output_run[:,-1] == state_run)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'rnn/transpose:0' shape=(2, 3, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"outputs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'rnn/while/Exit_2:0' shape=(2, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"state" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Multi RNN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Start from a fresh default graph for the stacked-RNN example.\n", | |
"tf.reset_default_graph()\n", | |
"\n", | |
"# Toy input with shape (batch_size=2, sequence_length=3, num_features=1).\n", | |
"values = tf.constant(np.array([\n", | |
"    [[1], [2], [3]],\n", | |
"    [[2], [3], [4]]\n", | |
"]), dtype=tf.float32)\n", | |
"\n", | |
"\n", | |
"def lstm_cell():\n", | |
"    # Factory returning a new 100-unit LSTM cell on every call; it is called\n", | |
"    # once per layer below so each layer gets its own cell object.\n", | |
"    # Written as a def rather than a name-bound lambda (PEP 8, E731).\n", | |
"    return tf.contrib.rnn.LSTMCell(100)\n", | |
"\n", | |
"\n", | |
"# Stack three LSTM layers into one composite cell.\n", | |
"multi_cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(3)])\n", | |
"\n", | |
"# outputs: shape (2, 3, 100).\n", | |
"# state: a tuple with one LSTMStateTuple(c, h) per layer, each part (2, 100).\n", | |
"outputs, state = tf.nn.dynamic_rnn(cell=multi_cell, dtype=tf.float32, inputs=values)\n", | |
"\n", | |
"# Fetch both in a single run so they come from the same forward pass.\n", | |
"with tf.Session() as sess:\n", | |
"    sess.run(tf.global_variables_initializer())\n", | |
"    output_run, state_run = sess.run([outputs, state])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.all(output_run[:,-1] == state_run[-1].h)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<tf.Tensor 'rnn/transpose:0' shape=(2, 3, 100) dtype=float32>" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"outputs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_2:0' shape=(2, 100) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_3:0' shape=(2, 100) dtype=float32>),\n", | |
" LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_4:0' shape=(2, 100) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_5:0' shape=(2, 100) dtype=float32>),\n", | |
" LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_6:0' shape=(2, 100) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_7:0' shape=(2, 100) dtype=float32>))" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"state" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Your tutorial is a life saver. Brilliant and point on!