Created
September 17, 2018 13:49
-
-
Save cjauvin/5ac17dbc4665a2719f7d78bf1d51c450 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"\n", | |
"tf.reset_default_graph()\n", | |
"\n", | |
"M = 10 # minibatch size\n", | |
"K = 10 # rows: number of z components\n", | |
"N = 10 # columns: number of possible discrete states for a z component \n", | |
"\n", | |
"X_a = tf.placeholder(tf.float32, [M, 14 * 28]) # half image\n", | |
"hidden1_a = tf.contrib.layers.fully_connected(X_a, 300, activation_fn=tf.nn.relu)\n", | |
"hidden2_a = tf.contrib.layers.fully_connected(hidden1_a, 100, activation_fn=tf.nn.relu)\n", | |
"logits_a = tf.reshape(hidden2_a, [M, K, N])\n", | |
"\n", | |
"X_b = tf.placeholder(tf.float32, [M, 14 * 28]) # half image\n", | |
"hidden1_b = tf.contrib.layers.fully_connected(X_b, 300, activation_fn=tf.nn.relu)\n", | |
"hidden2_b = tf.contrib.layers.fully_connected(hidden1_b, 100, activation_fn=tf.nn.relu)\n", | |
"logits_b = tf.reshape(hidden2_b, [M, K, N])\n", | |
"\n", | |
"#logits_a = tf.Variable(tf.random_uniform(shape=(M, K, N)))\n", | |
"#logits_b = tf.Variable(tf.random_uniform(shape=(M, K, N)))\n", | |
"logits_p = tf.Variable(tf.random_uniform(shape=(K, N)))\n", | |
"\n", | |
"# These can be precomputed\n", | |
"lse_a = tf.nn.log_softmax(logits_a, axis=1)\n", | |
"lse_b = tf.nn.log_softmax(logits_b, axis=1)\n", | |
"lse_p = tf.nn.log_softmax(logits_p, axis=1)\n", | |
"\n", | |
"def agreement(i, j):\n", | |
" a = logits_a[i] - lse_a[i] + logits_b[j] - lse_b[j] - logits_p + lse_p # (K, N) matrix\n", | |
" # LSE-reduce rows first, and then reduce the resulting k-size vector to a scalar\n", | |
" return tf.reduce_sum(tf.reduce_logsumexp(a, axis=1))\n", | |
"\n", | |
"# First term of the Holy Loss: \\sum_n A(La, Lb), for all diag pairs (there are M)\n", | |
"\n", | |
"c = lambda i, a: i < M\n", | |
"b = lambda i, a: [i + 1, a + agreement(i, i)] # accumulate agreement values\n", | |
"_, holy_loss_first_term = tf.while_loop(c, b, [0, 0.])\n", | |
"holy_loss_first_term = -holy_loss_first_term / M\n", | |
"\n", | |
"# Second term of the Holy Loss: LSE A(La, Lb) over all non-diag pairs (there are M / (M - 1)) \n", | |
"\n", | |
"def inner_while_loop(i, r):\n", | |
" c = lambda j, ea: j < M\n", | |
" # accumulate exp(agreement) values (if non-diag)\n", | |
" b = lambda j, ea: [j + 1, ea + tf.cond(tf.not_equal(i, j), lambda: tf.exp(agreement(i, j)), lambda: 0.)]\n", | |
" _, t = tf.while_loop(c, b, [0, r])\n", | |
" return i + 1, t\n", | |
"\n", | |
"c = lambda i, ea: i < M\n", | |
"_, holy_loss_second_term = tf.while_loop(c, inner_while_loop, [0, 0.])\n", | |
"holy_loss_second_term = tf.log(holy_loss_second_term) - tf.log(M / (M - 1))\n", | |
"\n", | |
"# The full Holy Loss\n", | |
"holy_loss = holy_loss_first_term + holy_loss_second_term\n", | |
"\n", | |
"#train_op = tf.train.AdamOptimizer().minimize(holy_loss)\n", | |
"train_op = tf.train.GradientDescentOptimizer(0.01).minimize(holy_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.image.AxesImage at 0xb39d43780>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAD8CAYAAAAVOD3kAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADbxJREFUeJzt3W+IXYWZx/HfL9pASFWMbqaDDRrWuG4oMYUxrAQ0JTakUoh9oVSwBFYyReqLwr4w6IsWZEHB7dIXZTW1w6SY2hbiYF7UtiEvlMVNyEyRmpjESEjbMSFpsLJZFUv02RdzxoxxnnNu7r9z753vB+Tee55z7nnQx1/OOTn3XkeEAACft6juBgCgVxGQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASFzZzZ3Z5mM7/edcRPxD3U30Oma7/0SEq9Zp6QjS9mbbx2y/bXt7K++FnvWnuhuoA7MNqYWAtH2FpJ9I+oak1ZIesL26XY0BdWG2MauVI8h1kt6OiBMR8XdJv5S0pT1tAbVitiGptYC8QdJf5ryeLpZ9hu1R25O2J1vYF9BNzDYktfaXNPNd4PzcheqI2CFph8SFbPQNZhuSWjuCnJa0Ys7rL0s61Vo7QE9gtiGptYA8KGmV7ZW2F0v6tqQ97WkLqBWzDUktnGJHxAXbj0j6naQrJI1FxOG2dQbUhNnGLHfzJxe4TtOXpiJipO4meh2z3X86fqM4AAwyAhIAEgQkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkLiylY1tn5R0XtLHki5ExEg7mgLqtpBme+PGjWlt165dpdveddddpfVjx4411VOvaCkgC1+LiHNteB+g1zDbCxyn2ACQaDUgQ9LvbU/ZHm1HQ0CPYLbR8in2+og4ZXu5pL22j0bEq3NXKIaLAUO/YbbR2hFkRJwqHs9KmpC0bp51dkTEyCBf5MbgYbYhtRCQtpfavmr2uaRNkg61qzGgLsw2ZrVyij0kacL27Pv8IiJ+25augHox25DUQkBGxAlJt7Wxl4668847S+vXXXddaX1iYqKd7aCH9dtst+r2229PawcPHuxiJ72H23wAIEFAAkCCgASABAEJAAkCEgASBCQAJNrxbT59YcOGDaX1VatWlda5zQf9atGi8uOglStXprUbb7yxdNviXtGBxREkACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQCJBXMf5NatW0vrr732Wpc6AbpreHi4tL5t27a09vzzz5due/To0aZ66hccQQJAgoAEgAQBCQAJAhIAEgQkACQISABIEJAAkFgw90FWfSceMKiee+65prc9fvx4GzvpP6QGACQISABIEJAAkCAgASBBQAJAgoAEgAQBCQCJyvsgbY9J+qaksxHxlWLZMkm/knSTpJOS7o+Iv3WuzWpr1qwprQ8NDXWpE/SLfpntVl1zzTVNb7t37942dtJ/GjmCHJe0+ZJl2yXti4hVkvYVr4F+My5mGyUqAzIiXpX07iWLt0jaWTzfKeneNvcFdByzjSrNXoMciojTklQ8Lm9fS0CtmG18quOfxbY9Kmm00/sBuo3ZHnzNHkGesT0sScXj2WzFiNgRESMRMdLkvoBuYrbxqWYDco+k2Z8J3Crppfa0A9SO2canKgPS9guS/kfSP9metv2QpCclfd32cUlfL14DfYXZRpXKa5AR8UBS2tjmXlpyzz33lNaXLFnSpU7QL/pltqtU3eO7cuXKpt/7nXfeaXrbQcAnaQAgQUACQIKABIAEAQkACQISABIEJAAkBuZnX2+99daWtj98+HCbOgG66+mnny6tV90G9NZbb6W18+fPN9XToOAIEgASBCQAJAhIAEgQkACQICABIEFAAkCCgASAxMDcB9mqgwcP1t0CBtjVV19dWt+8+dIfV7zowQcfLN1206ZNTfU064knnkhr7733Xkvv3e84ggSABAEJAAkCEgASBCQAJAhIAEgQkACQICABIMF9kIVly5bVtu/bbruttG67tH733XentRUrVpRuu3jx4tL6ww8/XFrHjLVr1+qVV15J64sWlR+LfPjhh2ntwIEDpdt+9NFHpfUrryz/33xqaqq0vpBxBAkACQISABIEJAAkCEgASBCQAJAgIAEgQUACQKLyPkjbY5K+KelsRHylWPZDSdsk/bVY7bGI+E2nmmxE2X1kkhQRpfVnn322tP74449fdk+NWrNmTWm96j7ICxcupLUPPvigdNs333yztD7I2jnb586d09jYWFqfnJws3b7sHsozZ86Ubjs9PV1aX7JkSWn96NGjpfWFrJEjyHFJ832b539GxNrin1rDEWjSuJhtlKgMyIh4VdK7XegF6CpmG1VauQb5iO0/2h6zfW3bOgLqx2xDUvMB+V+S/lHSWkmnJf1HtqLtUduTtssvwgC9oanZfv/997vVH7qoqYCMiDMR8XFEfCLpp5LWlay7IyJGImKk2SaBbml2tpcuXdq9JtE1TQWk7eE5L78l6VB72gHqxWxjrkZu83lB0gZJ19uelvQDSRtsr5UUkk5K+m4HewQ6gtlGFVfdH9jWndnd29klHn300dL6+vXru9TJ5ZuYmCitHzlyJK3t37+/1d1PcXmkWidne3R0tLT+zDPPlNZPnDhRWr/55psvu6dBEBHlNxiLT9IAQIqABIAEAQkACQISABIEJAAkCEgASCyYn3196qmn6m4BaMrGjRtb2n737t1t6mTh4QgSABIEJAAkCEgASBCQAJAgIAEgQUACQIKABIDEgrkPElioqr4uDzmOIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASBCQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASFR+H6TtFZJ+LulLkj6RtCMifmx7maRfSbpJ0klJ90fE3zrXKtBegzLbtkvrt9xyS2l9//797WxnoDRyBHlB0r9FxD9L+hdJ37O9WtJ2SfsiYpWkfcVroJ8w2yhVGZARcToi/lA8Py/piKQbJG2RtLNYbaekezvVJNAJzDaqXNY1SNs3SfqqpAOShiLitDQzaJKWt7s5oFuYbcyn4d+ksf1FSbslfT8i/rfqusec7UYljTbXHtB5zDYyDR1B2v6CZgZoV0S8WCw+Y3u4qA9LOjvfthGxIyJGImKkHQ0D7cRso0xlQHrmj9OfSToSET+aU9ojaWvxfKukl9rfHtA5zDaqNHKKvV7SdyS9Yfv1Ytljkp6U9GvbD0n6s6T7OtMi0DEDMdsRUVpftIjbnZtVGZAR8d+SsosyG9vbDtA9zDaq8EcLACQISABIEJAAkCAgASBBQAJAgoAEgETDHzUE0J/uuOOO0vr4+Hh3GulDHEECQIKABIAEAQkACQISABIEJAAkCEgASBCQAJDgPkigzzX6ExG4fBxBAkCCgASABAEJAAkCEgASBCQAJAhIAEgQkACQ4D5IoMe9/PLLpfX77uvpn+3uaxxBAkCCgASABAEJAAkCEgASBCQAJAhIAEgQkACQcESUr2CvkPRzSV+S9ImkHRHxY9s/lLRN0l+LVR+LiN9UvFf5ztCLpiJipO4mOoHZXtgiovKLNBsJyGFJwxHxB9tXSZqSdK+k+yX9X0Q83WhDDFFfGuSAZLYXsEYCsvKTNBFxWtLp4vl520ck3dB6e0C9mG1UuaxrkLZvkvRVSQeKRY/Y/qPtMdvXJtuM2p60PdlSp0AHMduYT+Up9qcr2l+U9Iqkf4+IF20PSTonKSQ9oZlTlX+teA9OQ/rPwJ5iz2K2F6ZGTrEbOoK0/QVJuyXtiogXizc/ExEfR8Qnkn4qaV0rzQJ1YLZRpjIgPfOTaT+TdCQifjRn+fCc1b4l6VD72wM6h9lGlUa+7my9pO9IesP268WyxyQ9YHutZk5DTkr6bkc6BDqH2Uaphq9BtmVnXKfpRwN/DbIdmO3+07ZrkACwEBGQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAkCEgASBCQAJAgIAEgQUACQIKABIAEAQkAiUa+D7Kdzkn605zX1xfLelGv9tbtvm7s4r762dzZ7tXZkehtVkNz3dXvg/zczu3JXv2uwV7trVf7wkW9/N+I3i4Pp9gAkCAgASBRd0DuqHn/ZXq1t17tCxf18n8jersMtV6DBIBeVvcRJAD0rFoC0vZm28dsv217ex09ZGyftP2G7ddtT9bcy5jts7YPzVm2zPZe28eLx2vr7BGfxWw33EtfzHbXA9L2FZJ+IukbklZr5jeIV3e7jwpfi4i1PXDLwbikzZcs2y5pX0SskrSveI0ewGxflnH1wWzXcQS5TtLbEXEiIv4u6ZeSttTQR8+LiFclvXvJ4i2SdhbPd0q6t6tNoQyz3aB+me06AvIGSX+Z83q6WNYrQtLvbU/ZHq27mXkMRcRpSSoel9fcDy5itlvTc7Pd7Y8aSpLnWdZLf5W+PiJO2V4uaa/to8WfdkAVZnvA1HEEOS1pxZzXX5Z0qoY+5hURp4rHs5ImNHPa1EvO2B6WpOLxbM394CJmuzU9N9t1BORBSatsr7S9WNK3Je2poY/Psb3U9lWzzyVtknSofKuu2yNpa/F8q6SXauwFn8Vst6bnZrvrp9gRccH2I5J+J+kKSWMRcbjbfSSGJE3Ylmb+3fwiIn5bVzO2X5C0QdL1tqcl/UDSk5J+bfshSX+WdF9d/eGzmO3G9cts80kaAEjwSRoASBCQAJAgIAEgQUACQIKABIAEAQkACQISABIEJAAk/h/QV9BsToS1tgAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 2 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"from keras.datasets import mnist\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"\n", | |
"(X_train, y_train), (X_test, y_test) = mnist.load_data()\n", | |
"X_train = X_train / 255.\n", | |
"\n", | |
"X_train_a = X_train[:, :, :14].reshape((60000, 392))\n", | |
"X_train_b = X_train[:, :, 14:].reshape((60000, 392))\n", | |
"\n", | |
"plt.subplot(1, 2, 1)\n", | |
"plt.imshow(X_train_a[2].reshape((28, 14)), cmap='gray')\n", | |
"plt.subplot(1, 2, 2)\n", | |
"plt.imshow(X_train_b[2].reshape((28, 14)), cmap='gray')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 4.27146275583903\n" | |
] | |
} | |
], | |
"source": [ | |
"n_epochs = 10\n", | |
"n_train_batches = len(X_train) // M\n", | |
"\n", | |
"with tf.Session(graph=tf.get_default_graph()) as sess:\n", | |
" sess.run(tf.global_variables_initializer())\n", | |
"\n", | |
" for epoch in range(n_epochs): \n", | |
" epoch_loss = 0\n", | |
" for i in range(n_train_batches):\n", | |
" j = i * M\n", | |
" X_train_batch_a = X_train_a[j:(j + M)]\n", | |
" X_train_batch_b = X_train_b[j:(j + M)]\n", | |
" res = sess.run([train_op, holy_loss], feed_dict={X_a: X_train_batch_a, X_b: X_train_batch_b})\n", | |
" epoch_loss += res[1]\n", | |
" print(epoch, epoch_loss / n_train_batches)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment