Last active
December 13, 2018 18:23
-
-
Save NTT123/9823e517b691094c64472d0f2baf6ae1 to your computer and use it in GitHub Desktop.
Sonnet MNIST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Sonnet MNIST", | |
"version": "0.3.2", | |
"provenance": [], | |
"collapsed_sections": [ | |
"gDIpKEqMFJsM", | |
"GxUq9wH5RTZL" | |
], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/NTT123/9823e517b691094c64472d0f2baf6ae1/sonnet-mnist.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "542hNYow-N-S", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## DeepMind Sonnet + MNIST" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "HwfC79RJxqxf", | |
"colab_type": "toc" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
">>[Deepmind Sonnet + MNIST](#updateTitle=true&folderId=1vS4NPhLsF-s0dhVeKeGK54KqT7dstcOZ&scrollTo=542hNYow-N-S)\n", | |
"\n", | |
">>>[Model](#updateTitle=true&folderId=1vS4NPhLsF-s0dhVeKeGK54KqT7dstcOZ&scrollTo=nCn-Q0RsEjJV)\n", | |
"\n", | |
">>>[Tensorboard](#updateTitle=true&folderId=1vS4NPhLsF-s0dhVeKeGK54KqT7dstcOZ&scrollTo=ZJbtnfA2cq4M)\n", | |
"\n", | |
">>>[Dataloader](#updateTitle=true&folderId=1vS4NPhLsF-s0dhVeKeGK54KqT7dstcOZ&scrollTo=gDIpKEqMFJsM)\n", | |
"\n", | |
">>>[Define Loss and Accuracy](#updateTitle=true&folderId=1vS4NPhLsF-s0dhVeKeGK54KqT7dstcOZ&scrollTo=GxUq9wH5RTZL)\n", | |
"\n", | |
">>>[Training](#updateTitle=true&folderId=1vS4NPhLsF-s0dhVeKeGK54KqT7dstcOZ&scrollTo=2VTv6CcwRYx9)\n", | |
"\n" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "nCn-Q0RsEjJV", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"### Model" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "DEX1uUX29-Jv", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"import tensorflow as tf\n", | |
"import sonnet as snt" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "PLMy4zgWaAnF", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"class Module(snt.AbstractModule):\n", | |
" \n", | |
" def __init__(self, *args, **kwargs):\n", | |
" super().__init__( *args, **kwargs)\n", | |
" \n", | |
" self.is_training = True\n", | |
" \n", | |
" def eval(self):\n", | |
" self.is_training = False\n", | |
" \n", | |
" def train(self):\n", | |
" self.is_training = True" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "9B8Rm6mX-Sd1", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"class MLP(Module):\n", | |
" def __init__(self):\n", | |
" super().__init__()\n", | |
" \n", | |
" with self._enter_variable_scope():\n", | |
" self.linear1 = snt.Linear(128)\n", | |
" self.linear2 = snt.Linear(256)\n", | |
" self.linear3 = snt.Linear(10)\n", | |
" \n", | |
" def _build(self, x):\n", | |
" x = tf.layers.flatten(x)\n", | |
" x = self.linear1(x)\n", | |
" x = tf.nn.relu(x)\n", | |
" x = tf.nn.dropout(x, 0.9) if self.is_training else x\n", | |
" x = self.linear2(x)\n", | |
" x = tf.nn.relu(x)\n", | |
" x = tf.nn.dropout(x, 0.9) if self.is_training else x\n", | |
" x = self.linear3(x)\n", | |
" return x" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "ZJbtnfA2cq4M", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"### Tensorboard" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "cUfT26OlBNTN", | |
"colab_type": "code", | |
"outputId": "6e80a2a4-a26b-4a18-bcd5-8252376224d7", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 102 | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"!pip install tensorboardcolab\n", | |
"from tensorboardcolab import *\n", | |
"\n", | |
"x = tf.placeholder(tf.float32, shape=(1, 28,28))\n", | |
"net = MLP()\n", | |
"y = net(x)\n", | |
"tbc=TensorBoardColab(graph_path=\"./log\")\n", | |
"filewriter = tf.summary.FileWriter(\"./log/1\", net.graph)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: tensorboardcolab in /usr/local/lib/python3.6/dist-packages (0.0.20)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Wait for 8 seconds...\n", | |
"TensorBoard link:\n", | |
"http://3e5efe18.ngrok.io\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "gDIpKEqMFJsM", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"### Dataloader" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "nsB_-B0nBYoH", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"tf.reset_default_graph()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "pG2Zm8fgEXZP", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"mnist = tf.keras.datasets.mnist\n", | |
"(x_train, y_train),(x_test, y_test) = mnist.load_data()\n", | |
"x_train, x_test = x_train / 255.0, x_test / 255.0" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "lWDLlcEKKQLB", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"mnist_ds = tf.data.Dataset.from_tensor_slices((x_train,y_train))\n", | |
"mnist_ds = mnist_ds.repeat().batch(32, drop_remainder=True).prefetch(2)\n", | |
"mnist_ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))\n", | |
"mnist_ds_test = mnist_ds_test.repeat().batch(128, drop_remainder=True).prefetch(2)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "GxUq9wH5RTZL", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"### Define Loss and Accuracy" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "sx0Ipn6FHLuV", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"it = mnist_ds.make_one_shot_iterator()\n", | |
"x, target = it.get_next()\n", | |
"x = tf.to_float(x)\n", | |
"target = tf.to_int32(target)\n", | |
"net = MLP()\n", | |
"y = net(x)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "woVQ5044H_n4", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=y)\n", | |
"loss = tf.reduce_mean(loss)\n", | |
"loss_summary = tf.summary.scalar(\"loss\", loss)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "X4UN_Y5sh8y5", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"it_test = mnist_ds_test.make_one_shot_iterator()\n", | |
"x, target = it_test.get_next()\n", | |
"x = tf.to_float(x)\n", | |
"target = tf.to_int32(target)\n", | |
"net.eval()\n", | |
"y = net(x)\n", | |
"net.train()\n", | |
"y = tf.argmax(y, axis=1, output_type=tf.int32)\n", | |
"acc, acc_ops = tf.metrics.accuracy(labels=target, predictions=y)\n", | |
"with tf.control_dependencies([acc, acc_ops]):\n", | |
" acc_summary = tf.summary.scalar(\"accuracy\", acc)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "2VTv6CcwRYx9", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"### Training" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "aO9rP0u1IRtL", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"optimizer = tf.train.AdamOptimizer(1e-4)\n", | |
"train_ops = optimizer.minimize(loss, var_list=net.get_all_variables())" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Vt3n57d3Mgjc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"sess = tf.Session()\n", | |
"sess.run(tf.global_variables_initializer())\n", | |
"sess.run(tf.local_variables_initializer())" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "E5cFK5FWZCt8", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"step = 0" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "399d4l7xMmVf", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"for step in range(step, 1_000_000):\n", | |
" _, ls = sess.run([train_ops, loss_summary])\n", | |
" filewriter.add_summary(ls, step)\n", | |
" \n", | |
" if step % 100 == 0:\n", | |
" accsum = sess.run(acc_summary)\n", | |
" filewriter.add_summary(accsum, step)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment