@titu1994
Created April 24, 2018 02:16
TF Eager JANet model
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From D:\\Users\\Yue\\Anaconda3\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use the retry module or similar alternatives.\n"
]
}
],
"source": [
"import numpy as np\n",
"from collections import OrderedDict\n",
"import tensorflow as tf\n",
"from tensorflow.contrib.eager.python import tfe\n",
"\n",
"# This silent device placement will hurt performance on GPUs, but we are sticking to CPUs only so it\n",
"# should be fine\n",
"tf.enable_eager_execution(device_policy=tfe.DEVICE_PLACEMENT_SILENT)\n",
"tf.set_random_seed(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Goal\n",
"We are going to implement the JANet model from the paper [\"The unreasonable effectiveness of the forget gate\"](https://arxiv.org/abs/1804.04849) and try the addition experiment to see if it converges properly. \n",
"\n",
"This is not going to be a high-performance model, and it will probably not match the results reported in the paper. It is a simple exercise in how easy it is to write down research models in TensorFlow Eager, without all the baggage of TensorFlow graphs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Components we need\n",
"For JANet, we need two components:\n",
"\n",
"1) Chrono Initializer from the paper [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)\n",
"\n",
"2) The RNN cell, which will do most of the work\n",
"\n",
"Let's create the initializer first, since it is simpler."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Chrono Initializer\n",
"In short, if we know that the sequence lengths lie in the range $[T_{min}, T_{max}]$, then it is beneficial to initialize the model with forgetting times in the same range. This is equivalent to initializing the gate $g$ with values in the range $[\\frac{1}{T_{max}}, \\frac{1}{T_{min}}]$.\n",
"\n",
"If the values of both inputs and hidden layers are centered over time, $g(t)$ will typically take values centered around $\\sigma(b_g)$. To obtain values in the desired range above, the biases $b_g$ must lie in the range $[-\\log(T_{max} - 1), -\\log(T_{min} - 1)]$.\n",
"\n",
"In an LSTM, the forget gate $f_t$ at time step $t$ corresponds to $1 - g_t$, whereas the input gate $i_t$ corresponds to $g_t$. Drawing the forgetting time $T$ such that $T - 1$ is uniform over $[1, T_{max} - 1]$ (so $T$ ranges from 2 to $T_{max}$), the paper suggests initializing the forget gate and input gate biases as follows:\n",
"\n",
"\\begin{align}\n",
"b_f &\\sim \\log(\\mathcal{U}([1, T_{max} - 1])) \\\\\n",
"b_i &= -b_f\n",
"\\end{align}\n",
"\n",
"For a complete explanation of ChronoInitializer, refer to the above paper.\n"
]
},
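{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check of the formula above, the next cell is a NumPy-only sketch, independent of the TensorFlow code (`sigmoid` and `chrono_bias` are helper names introduced just for this illustration). It samples $b_f \\sim \\log(\\mathcal{U}([1, T_{max} - 1]))$ with $T_{max} = 100$, sets $b_i = -b_f$, and recovers the implied forgetting time $T = 1 / (1 - \\sigma(b_f)) = e^{b_f} + 1$, which should lie between 2 and $T_{max}$, while $\\sigma(b_i) = 1 / T$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"def chrono_bias(units, t_max, rng):\n",
"    # Hypothetical helper: b_f = log(U([1, t_max - 1])), one bias per unit\n",
"    return np.log(rng.uniform(1.0, t_max - 1.0, size=units))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"b_f = chrono_bias(128, 100, rng)\n",
"b_i = -b_f\n",
"\n",
"# Implied forgetting time: 1 / (1 - sigmoid(b_f)) = exp(b_f) + 1, expected in [2, 100]\n",
"forget_times = 1.0 / (1.0 - sigmoid(b_f))\n",
"print(forget_times.min(), forget_times.max())\n",
"\n",
"# The input gate bias gives sigmoid(b_i) = 1 - sigmoid(b_f) = 1 / T\n",
"print(np.allclose(sigmoid(b_i), 1.0 / forget_times))  # True"
]
},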
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class ChronoInitializer(tf.keras.initializers.RandomUniform):\n",
"    \"\"\"\n",
"    Chrono Initializer from the paper:\n",
"    [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)\n",
"\n",
"    Samples biases as log(U([1, max_timesteps - 1])).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, max_timesteps, seed=None):\n",
"        super(ChronoInitializer, self).__init__(1., max_timesteps - 1, seed)\n",
"        self.max_timesteps = max_timesteps\n",
"\n",
"    def __call__(self, shape, dtype=None, partition_info=None):\n",
"        # Draw from U([1, max_timesteps - 1]) with the parent initializer, then take the log\n",
"        values = super(ChronoInitializer, self).__call__(shape, dtype=dtype, partition_info=partition_info)\n",
"        return tf.log(values)\n",
"\n",
"    def get_config(self):\n",
"        # Store only the constructor arguments: minval/maxval are derived from max_timesteps,\n",
"        # and __init__ does not accept the parent's minval/maxval/dtype, so from_config would fail\n",
"        return {\n",
"            'max_timesteps': self.max_timesteps,\n",
"            'seed': self.seed,\n",
"        }\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RNN Cell\n",
"\n",
"The model described in the paper [\"The unreasonable effectiveness of the forget gate\"](https://arxiv.org/abs/1804.04849) is called JANet: of the LSTM's gates, it keeps only the forget gate $(g_f)$, alongside the memory $(g_c)$ gate, and drops the input $(g_i)$ and output $(g_o)$ gates. In the implementation below, $1 - g_f$ takes the place of the input gate, and the hidden state is simply the memory cell."
]
},
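{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written out, the recurrence implemented by `JANETModel.call` below is as follows. This restates the code rather than quoting the paper, whose notation may differ: $W$ is the input kernel (`self.kernel`), $U$ the recurrent kernel and $b$ its bias (`self.recurrent_kernel`), each split into a forget half $(f)$ and a memory half $(c)$.\n",
"\n",
"\\begin{align}\n",
"f_t &= \\sigma(W_f x_t + U_f h_{t-1} + b_f) \\\\\n",
"\\tilde{c}_t &= \\tanh(W_c x_t + U_c h_{t-1} + b_c) \\\\\n",
"c_t &= f_t \\odot c_{t-1} + (1 - f_t) \\odot \\tilde{c}_t \\\\\n",
"h_t &= c_t\n",
"\\end{align}"
]
},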
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class JANETModel(tf.keras.Model):\n",
"\n",
"    def __init__(self, units, num_outputs, num_timesteps, output_activation='sigmoid', **kwargs):\n",
"        super(JANETModel, self).__init__(**kwargs)\n",
"\n",
"        self.units = units\n",
"        self.classes = num_outputs\n",
"        self.num_timesteps = num_timesteps\n",
"\n",
"        # The recurrent bias is laid out as [forget gate | memory]:\n",
"        # the forget gate half uses ChronoInitializer, the memory half is initialized with zeros.\n",
"        # The requested shape (first argument) is ignored, since it is always (2 * units,)\n",
"        def bias_initializer(_, *args, **kwargs):\n",
"            forget_gate = ChronoInitializer(self.num_timesteps)((self.units,), *args, **kwargs)\n",
"\n",
"            return tf.keras.backend.concatenate([\n",
"                forget_gate,\n",
"                tf.keras.initializers.Zeros()((self.units,), *args, **kwargs),\n",
"            ])\n",
"\n",
"        # Input kernel (no bias) and recurrent kernel (carries the bias above)\n",
"        self.kernel = tf.keras.layers.Dense(2 * units, use_bias=False,\n",
"                                            kernel_initializer='glorot_uniform')\n",
"\n",
"        self.recurrent_kernel = tf.keras.layers.Dense(2 * units,\n",
"                                                      kernel_initializer='glorot_uniform',\n",
"                                                      bias_initializer=bias_initializer)\n",
"\n",
"        # Final layer (for classification or regression, depending on the output activation)\n",
"        self.output_dense = tf.keras.layers.Dense(num_outputs, activation=output_activation)\n",
"\n",
"    # This override works around an issue with how the weights of a model are loaded\n",
"    # from a checkpoint.\n",
"    #\n",
"    # While we could use `model.call(inputs)` directly, the more pythonic way of doing this is\n",
"    # by using `model(inputs)`.\n",
"    #\n",
"    # However, in Eager mode, the ordinary function call does *not* forward to the\n",
"    # model's `call` method. This bypasses that issue.\n",
"    def __call__(self, *args, **kwargs):\n",
"        if not tf.executing_eagerly():\n",
"            super(JANETModel, self).__call__(*args, **kwargs)\n",
"        return self.call(*args, **kwargs)\n",
"\n",
"    def call(self, inputs, training=None, mask=None):\n",
"        # Initialize the hidden and memory states with zeros\n",
"        outputs = []\n",
"        states = []\n",
"        h_state = tf.zeros((inputs.shape[0], self.units))\n",
"        c_state = tf.zeros((inputs.shape[0], self.units))\n",
"\n",
"        # Input has the shape [batch_size, timesteps, input_dim]\n",
"        for t in range(inputs.shape[1]):\n",
"            ip = inputs[:, t, :]  # input at timestep t\n",
"\n",
"            # Pre-activations for both the forget gate and the memory candidate\n",
"            z = self.kernel(ip)\n",
"            z += self.recurrent_kernel(h_state)\n",
"\n",
"            # Split into the forget and memory pre-activations\n",
"            z0 = z[:, :self.units]\n",
"            z1 = z[:, self.units: 2 * self.units]\n",
"\n",
"            # Gate updates: the input gate is tied to (1 - f)\n",
"            f = tf.keras.activations.sigmoid(z0)\n",
"            c = f * c_state + (1. - f) * tf.nn.tanh(z1)\n",
"\n",
"            # JANET has no output gate, so the hidden state is the memory cell\n",
"            h = c\n",
"\n",
"            # Carry the states over to the next timestep\n",
"            h_state = h\n",
"            c_state = c\n",
"\n",
"            # Keep the history of outputs and states\n",
"            outputs.append(h)\n",
"            states.append([h])  # only one state variable is kept here, but there could be more\n",
"\n",
"        # Store the per-timestep outputs and states as attributes so they can be inspected later\n",
"        self.cell_outputs = tf.stack(outputs, axis=1)\n",
"        self.cell_states = states\n",
"\n",
"        # Forward pass of the classifier / regressor on the last hidden state\n",
"        preds = self.output_dense(outputs[-1])\n",
"\n",
"        return preds"
]
},
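{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the recurrence concrete without TensorFlow, the next cell is a NumPy sketch of the same forward pass as `JANETModel.call`, without the output layer. The function name `janet_forward`, the weight shapes and the random weights are assumptions made for this illustration; the bias uses the same `[forget | memory]` layout as `bias_initializer`, with a chrono-sampled forget half. Because $c_0 = 0$ and each $c_t$ is a convex combination of $c_{t-1}$ and a $\\tanh$ value, every hidden state should stay strictly inside $(-1, 1)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"def janet_forward(inputs, W, U, b):\n",
"    # inputs: (batch, timesteps, input_dim); W: (input_dim, 2 * units)\n",
"    # U: (units, 2 * units); b: (2 * units,) laid out as [forget | memory]\n",
"    batch, timesteps, _ = inputs.shape\n",
"    units = U.shape[0]\n",
"    c = np.zeros((batch, units))\n",
"    outputs = []\n",
"    for t in range(timesteps):\n",
"        # h_{t-1} equals c_{t-1} in JANET, so the recurrent term uses c\n",
"        z = inputs[:, t, :] @ W + c @ U + b\n",
"        f = sigmoid(z[:, :units])  # forget gate\n",
"        c = f * c + (1.0 - f) * np.tanh(z[:, units:])  # input gate tied to 1 - f\n",
"        outputs.append(c)  # h_t = c_t\n",
"    return np.stack(outputs, axis=1)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"units, input_dim, timesteps = 8, 2, 100\n",
"W = rng.normal(scale=0.1, size=(input_dim, 2 * units))\n",
"U = rng.normal(scale=0.1, size=(units, 2 * units))\n",
"b = np.concatenate([np.log(rng.uniform(1.0, timesteps - 1.0, size=units)), np.zeros(units)])\n",
"x = rng.normal(size=(4, timesteps, input_dim))\n",
"\n",
"h = janet_forward(x, W, U, b)\n",
"print(h.shape)  # (4, 100, 8)\n",
"print(np.abs(h).max() < 1.0)  # True"
]
},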
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"In Eager mode, if you use only tf.keras layers, training is a lot simpler: you can compile the model and then call its `.fit` and `.predict` methods. We deliberately do not do that for the current model, even though we could, because it is important to explain the limitations of sticking to models built only from Keras layers.\n",
"\n",
"We will declare some helper methods, along with the dataset loader. Since the addition task is simple enough, we will not use the canonical way to load data, a tf.data pipeline: for synthetic data generated on the fly, the overhead of a tf.data pipeline outweighs its benefit."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Define some constants\n",
"import os\n",
"\n",
"# Parameters taken from https://arxiv.org/abs/1804.04849\n",
"TIME_STEPS = 100\n",
"NUM_UNITS = 128\n",
"LEARNING_RATE = 0.001\n",
"STEPS_PER_EPOCH = 100\n",
"NUM_EPOCHS = 10\n",
"BATCH_SIZE = 50\n",
"\n",
"CHECKPOINTS_DIR = 'checkpoints_addition/'\n",
"CHECKPOINT_PATH = CHECKPOINTS_DIR + 'addition_janet.ckpt'\n",
"\n",
"# Create the output directories (including the checkpoint directory) if they do not exist yet\n",
"for directory in ('weights/', CHECKPOINTS_DIR):\n",
"    if not os.path.exists(directory):\n",
"        os.makedirs(directory)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Addition Dataset Generator"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Code reused from https://github.com/batzner/indrnn/blob/master/examples/addition_rnn.py\n",
"# with some modifications to run in TensorFlow Eager mode\n",
"def batch_generator():\n",
"    \"\"\"Generate batches of the adding problem dataset indefinitely\"\"\"\n",
"    while True:\n",
"        # Build the first sequence\n",
"        add_values = np.random.rand(BATCH_SIZE, TIME_STEPS)\n",
"\n",
"        # Build the second sequence with one 1 in each half and 0s otherwise\n",
"        add_indices = np.zeros_like(add_values, dtype='float32')\n",
"        half = TIME_STEPS // 2\n",
"        for i in range(BATCH_SIZE):\n",
"            first_half = np.random.randint(half)\n",
"            second_half = np.random.randint(half, TIME_STEPS)\n",
"            add_indices[i, [first_half, second_half]] = 1.\n",
"\n",
"        # Zip the values and indices in a third dimension:\n",
"        # inputs has the shape (batch_size, time_steps, 2)\n",
"        inputs = np.dstack((add_values, add_indices))\n",
"        targets = np.sum(np.multiply(add_values, add_indices), axis=1)\n",
"        targets = np.expand_dims(targets, -1)\n",
"\n",
"        # Center the inputs at zero mean across the batch (targets were computed before centering)\n",
"        inputs -= np.mean(inputs, axis=0, keepdims=True)\n",
"\n",
"        inputs = tf.convert_to_tensor(inputs, dtype=tf.float32)\n",
"        targets = tf.convert_to_tensor(targets, dtype=tf.float32)\n",
"\n",
"        yield inputs, targets\n"
]
},
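{
"cell_type": "markdown",
"metadata": {},
"source": [
"A useful reference point for the losses in the training log: the target is the sum of two independent $\\mathcal{U}(0, 1)$ values, so it has mean 1 and variance $2 \\times \\frac{1}{12} = \\frac{1}{6} \\approx 0.167$. A model that always predicts 1 therefore reaches an MSE of about 0.167 without learning anything. During epochs 2 to 4 of the run below, the logged losses mostly stay between 0.1 and 0.25, close to this baseline, and only fall clearly below it towards the end of epoch 5, which suggests that is roughly when the model starts using the markers. The next cell checks the baseline empirically with `make_adding_batch`, a hypothetical NumPy-only simplification of `batch_generator` (no centering, no TensorFlow tensors, vectorized marker placement)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def make_adding_batch(batch_size, time_steps, rng):\n",
"    # Hypothetical NumPy-only mirror of batch_generator (no centering, no TF tensors)\n",
"    values = rng.random((batch_size, time_steps))\n",
"    markers = np.zeros_like(values)\n",
"    half = time_steps // 2\n",
"    rows = np.arange(batch_size)\n",
"    markers[rows, rng.integers(0, half, size=batch_size)] = 1.0  # one marker in the first half\n",
"    markers[rows, rng.integers(half, time_steps, size=batch_size)] = 1.0  # one in the second half\n",
"    inputs = np.stack([values, markers], axis=-1)  # (batch_size, time_steps, 2)\n",
"    targets = (values * markers).sum(axis=1, keepdims=True)  # (batch_size, 1)\n",
"    return inputs, targets\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x_batch, y_batch = make_adding_batch(10000, 100, rng)\n",
"\n",
"# MSE of the trivial predictor that always outputs the mean target, 1.0\n",
"baseline_mse = np.mean((y_batch - 1.0) ** 2)\n",
"print(x_batch.shape, y_batch.shape, baseline_mse)  # baseline_mse should be close to 1/6"
]
},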
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loss and Gradients\n",
"The following functions compute the mean squared error and the gradients of that loss with respect to the model's trainable variables."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
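"def loss(y_true, y_pred):\n",
"    return tf.losses.mean_squared_error(y_true, y_pred)\n",
"\n",
"\n",
"def grad(model, X, y):\n",
"    # Record the forward pass so the tape can differentiate the loss\n",
"    with tfe.GradientTape() as tape:\n",
"        preds = model(X)\n",
"        loss_val = loss(y, preds)\n",
"\n",
"    variables = model.trainable_variables  # avoid shadowing the built-in `vars`\n",
"    grads = tape.gradient(loss_val, variables)\n",
"    grad_vars = zip(grads, variables)  # (gradient, variable) pairs, as expected by an optimizer's apply_gradients\n",
"\n",
"    return grad_vars, loss_val"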
]
},
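{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tape.gradient` computes these derivatives by automatic differentiation. As a TensorFlow-independent illustration of what it should return for a mean-squared-error loss, the next cell uses a hypothetical linear regressor $\\hat{y} = Xw$, whose gradient has the closed form $\\nabla_w \\frac{1}{N}\\lVert Xw - y \\rVert^2 = \\frac{2}{N} X^\\top (Xw - y)$, and compares it against central finite differences. The helper names `mse` and `mse_grad` are introduced only for this sketch; the same finite-difference idea can be used to spot-check gradients of a custom model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def mse(w, X, y):\n",
"    return np.mean((X @ w - y) ** 2)\n",
"\n",
"def mse_grad(w, X, y):\n",
"    # Closed form: d/dw mean((Xw - y)^2) = 2/N * X^T (Xw - y)\n",
"    return 2.0 / X.shape[0] * X.T @ (X @ w - y)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_lin = rng.normal(size=(50, 3))\n",
"y_lin = rng.normal(size=50)\n",
"w_lin = rng.normal(size=3)\n",
"\n",
"# Central finite differences along each coordinate direction\n",
"eps = 1e-6\n",
"numeric = np.array([(mse(w_lin + eps * e, X_lin, y_lin) - mse(w_lin - eps * e, X_lin, y_lin)) / (2 * eps)\n",
"                    for e in np.eye(3)])\n",
"print(np.allclose(numeric, mse_grad(w_lin, X_lin, y_lin), atol=1e-6))  # True"
]
},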
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training and Checkpointing"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch # : 1\n",
"1 : 1.1137999\n",
"2 : 1.3734939\n",
"3 : 1.2632025\n",
"4 : 1.0217344\n",
"5 : 1.1007522\n",
"6 : 1.0640762\n",
"7 : 0.77063787\n",
"8 : 0.84670275\n",
"9 : 0.79458064\n",
"10 : 0.7586803\n",
"11 : 0.5092878\n",
"12 : 0.42632163\n",
"13 : 0.24331588\n",
"14 : 0.17305855\n",
"15 : 0.22396189\n",
"16 : 0.46569645\n",
"17 : 0.30956933\n",
"18 : 0.19394228\n",
"19 : 0.18876888\n",
"20 : 0.14756978\n",
"21 : 0.1357087\n",
"22 : 0.24186754\n",
"23 : 0.21103485\n",
"24 : 0.16359738\n",
"25 : 0.29180482\n",
"26 : 0.30941334\n",
"27 : 0.18169452\n",
"28 : 0.17933026\n",
"29 : 0.19284649\n",
"30 : 0.19533329\n",
"31 : 0.17026216\n",
"32 : 0.2019141\n",
"33 : 0.16495672\n",
"34 : 0.18687668\n",
"35 : 0.18738216\n",
"36 : 0.16143692\n",
"37 : 0.21796392\n",
"38 : 0.20093755\n",
"39 : 0.1940393\n",
"40 : 0.13964733\n",
"41 : 0.14817266\n",
"42 : 0.20116977\n",
"43 : 0.19677643\n",
"44 : 0.18117912\n",
"45 : 0.1921833\n",
"46 : 0.15685205\n",
"47 : 0.16924348\n",
"48 : 0.13538565\n",
"49 : 0.13495164\n",
"50 : 0.1701503\n",
"51 : 0.1312775\n",
"52 : 0.15106285\n",
"53 : 0.14762932\n",
"54 : 0.11621122\n",
"55 : 0.16249685\n",
"56 : 0.18139866\n",
"57 : 0.1195894\n",
"58 : 0.19839394\n",
"59 : 0.18842766\n",
"60 : 0.15654941\n",
"61 : 0.21005692\n",
"62 : 0.13576888\n",
"63 : 0.1602078\n",
"64 : 0.20151004\n",
"65 : 0.12443797\n",
"66 : 0.16430213\n",
"67 : 0.1604485\n",
"68 : 0.12348908\n",
"69 : 0.16183636\n",
"70 : 0.2118262\n",
"71 : 0.1975509\n",
"72 : 0.18941139\n",
"73 : 0.2185139\n",
"74 : 0.15553287\n",
"75 : 0.15934338\n",
"76 : 0.1537831\n",
"77 : 0.1566895\n",
"78 : 0.12518919\n",
"79 : 0.19793692\n",
"80 : 0.1583711\n",
"81 : 0.2284777\n",
"82 : 0.17029156\n",
"83 : 0.18526001\n",
"84 : 0.17859228\n",
"85 : 0.15501665\n",
"86 : 0.20986289\n",
"87 : 0.15498531\n",
"88 : 0.13561767\n",
"89 : 0.15113877\n",
"90 : 0.17337379\n",
"91 : 0.15633701\n",
"92 : 0.18732792\n",
"93 : 0.18975239\n",
"94 : 0.1320506\n",
"95 : 0.15652293\n",
"96 : 0.13723686\n",
"97 : 0.2221598\n",
"98 : 0.16244914\n",
"99 : 0.1635633\n",
"100 : 0.115488395\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 2\n",
"101 : 0.15142712\n",
"102 : 0.19490778\n",
"103 : 0.15863025\n",
"104 : 0.16891563\n",
"105 : 0.1488823\n",
"106 : 0.19002002\n",
"107 : 0.13309555\n",
"108 : 0.1675126\n",
"109 : 0.11976473\n",
"110 : 0.13414559\n",
"111 : 0.17699897\n",
"112 : 0.15871114\n",
"113 : 0.18262321\n",
"114 : 0.15568721\n",
"115 : 0.14995322\n",
"116 : 0.22270572\n",
"117 : 0.18467392\n",
"118 : 0.19472942\n",
"119 : 0.13939682\n",
"120 : 0.15584089\n",
"121 : 0.15191117\n",
"122 : 0.20868172\n",
"123 : 0.19065133\n",
"124 : 0.20672204\n",
"125 : 0.15328361\n",
"126 : 0.15681319\n",
"127 : 0.17248166\n",
"128 : 0.15680744\n",
"129 : 0.1725671\n",
"130 : 0.15811142\n",
"131 : 0.1680556\n",
"132 : 0.14225216\n",
"133 : 0.14263257\n",
"134 : 0.14152426\n",
"135 : 0.12281683\n",
"136 : 0.15506937\n",
"137 : 0.25925797\n",
"138 : 0.12942582\n",
"139 : 0.17984153\n",
"140 : 0.15990013\n",
"141 : 0.103221685\n",
"142 : 0.14273256\n",
"143 : 0.16526337\n",
"144 : 0.19970125\n",
"145 : 0.1602799\n",
"146 : 0.18229164\n",
"147 : 0.1482605\n",
"148 : 0.16955715\n",
"149 : 0.15879424\n",
"150 : 0.15778762\n",
"151 : 0.18734623\n",
"152 : 0.10346366\n",
"153 : 0.17770061\n",
"154 : 0.13422662\n",
"155 : 0.0950175\n",
"156 : 0.15985726\n",
"157 : 0.16109079\n",
"158 : 0.19612525\n",
"159 : 0.26874956\n",
"160 : 0.10374393\n",
"161 : 0.14338039\n",
"162 : 0.17985679\n",
"163 : 0.18066671\n",
"164 : 0.11624183\n",
"165 : 0.15065072\n",
"166 : 0.1673728\n",
"167 : 0.17990631\n",
"168 : 0.17882518\n",
"169 : 0.19348824\n",
"170 : 0.1253838\n",
"171 : 0.20440447\n",
"172 : 0.17718393\n",
"173 : 0.12889002\n",
"174 : 0.20043197\n",
"175 : 0.12935922\n",
"176 : 0.16156515\n",
"177 : 0.1473184\n",
"178 : 0.20457672\n",
"179 : 0.19170585\n",
"180 : 0.12670062\n",
"181 : 0.11376371\n",
"182 : 0.12854005\n",
"183 : 0.1881057\n",
"184 : 0.19671673\n",
"185 : 0.1915549\n",
"186 : 0.123074725\n",
"187 : 0.14882633\n",
"188 : 0.15587412\n",
"189 : 0.1549204\n",
"190 : 0.15998723\n",
"191 : 0.15340252\n",
"192 : 0.13852304\n",
"193 : 0.1785473\n",
"194 : 0.18779467\n",
"195 : 0.14034723\n",
"196 : 0.1458752\n",
"197 : 0.16056164\n",
"198 : 0.13652371\n",
"199 : 0.19065769\n",
"200 : 0.20847234\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 3\n",
"201 : 0.14631604\n",
"202 : 0.19127764\n",
"203 : 0.1114826\n",
"204 : 0.14693926\n",
"205 : 0.18180187\n",
"206 : 0.1757907\n",
"207 : 0.16280074\n",
"208 : 0.22122839\n",
"209 : 0.1321277\n",
"210 : 0.14741613\n",
"211 : 0.16193426\n",
"212 : 0.1714394\n",
"213 : 0.20115525\n",
"214 : 0.19125421\n",
"215 : 0.19023961\n",
"216 : 0.22137111\n",
"217 : 0.17897464\n",
"218 : 0.1441868\n",
"219 : 0.19124608\n",
"220 : 0.16174869\n",
"221 : 0.17805603\n",
"222 : 0.16275558\n",
"223 : 0.15547869\n",
"224 : 0.18385968\n",
"225 : 0.16811806\n",
"226 : 0.14452225\n",
"227 : 0.14193898\n",
"228 : 0.1572187\n",
"229 : 0.15770416\n",
"230 : 0.18410131\n",
"231 : 0.171729\n",
"232 : 0.113374025\n",
"233 : 0.17020687\n",
"234 : 0.21779545\n",
"235 : 0.13953048\n",
"236 : 0.13485359\n",
"237 : 0.11338455\n",
"238 : 0.1580136\n",
"239 : 0.15234464\n",
"240 : 0.16268384\n",
"241 : 0.18441494\n",
"242 : 0.14589624\n",
"243 : 0.13381128\n",
"244 : 0.14633805\n",
"245 : 0.14886004\n",
"246 : 0.16279405\n",
"247 : 0.19427933\n",
"248 : 0.16472708\n",
"249 : 0.15572958\n",
"250 : 0.14369208\n",
"251 : 0.20090906\n",
"252 : 0.14735584\n",
"253 : 0.16509303\n",
"254 : 0.15320225\n",
"255 : 0.16610529\n",
"256 : 0.15259662\n",
"257 : 0.15674606\n",
"258 : 0.16167414\n",
"259 : 0.15840422\n",
"260 : 0.17905697\n",
"261 : 0.16085684\n",
"262 : 0.1844318\n",
"263 : 0.17283636\n",
"264 : 0.18175147\n",
"265 : 0.19120647\n",
"266 : 0.11100922\n",
"267 : 0.16280647\n",
"268 : 0.2003941\n",
"269 : 0.16432606\n",
"270 : 0.16657645\n",
"271 : 0.12586235\n",
"272 : 0.1741056\n",
"273 : 0.15290031\n",
"274 : 0.18318154\n",
"275 : 0.15236418\n",
"276 : 0.12173611\n",
"277 : 0.16374679\n",
"278 : 0.17364188\n",
"279 : 0.15512083\n",
"280 : 0.13950339\n",
"281 : 0.16697197\n",
"282 : 0.17575347\n",
"283 : 0.10168448\n",
"284 : 0.15351738\n",
"285 : 0.1746062\n",
"286 : 0.20178139\n",
"287 : 0.13283373\n",
"288 : 0.18089834\n",
"289 : 0.15519604\n",
"290 : 0.14023435\n",
"291 : 0.12693751\n",
"292 : 0.1658656\n",
"293 : 0.13202791\n",
"294 : 0.14162855\n",
"295 : 0.2071677\n",
"296 : 0.19718117\n",
"297 : 0.17673382\n",
"298 : 0.16001706\n",
"299 : 0.14818034\n",
"300 : 0.17414133\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 4\n",
"301 : 0.15989183\n",
"302 : 0.1683334\n",
"303 : 0.1634808\n",
"304 : 0.2368403\n",
"305 : 0.15235831\n",
"306 : 0.20588325\n",
"307 : 0.15767582\n",
"308 : 0.20982075\n",
"309 : 0.13626595\n",
"310 : 0.1461687\n",
"311 : 0.17249995\n",
"312 : 0.1621354\n",
"313 : 0.15050727\n",
"314 : 0.1332962\n",
"315 : 0.2033216\n",
"316 : 0.21490146\n",
"317 : 0.14389195\n",
"318 : 0.17702124\n",
"319 : 0.13090472\n",
"320 : 0.15642883\n",
"321 : 0.14335191\n",
"322 : 0.15122734\n",
"323 : 0.1517923\n",
"324 : 0.14177383\n",
"325 : 0.12271037\n",
"326 : 0.15354306\n",
"327 : 0.14585589\n",
"328 : 0.24630192\n",
"329 : 0.14025758\n",
"330 : 0.16827598\n",
"331 : 0.11754834\n",
"332 : 0.18819757\n",
"333 : 0.10686378\n",
"334 : 0.22321922\n",
"335 : 0.14689402\n",
"336 : 0.15480909\n",
"337 : 0.15470141\n",
"338 : 0.15753229\n",
"339 : 0.17729716\n",
"340 : 0.17660692\n",
"341 : 0.14480473\n",
"342 : 0.14537503\n",
"343 : 0.19240357\n",
"344 : 0.14784722\n",
"345 : 0.18822044\n",
"346 : 0.1150868\n",
"347 : 0.17408083\n",
"348 : 0.12811284\n",
"349 : 0.1476301\n",
"350 : 0.18280677\n",
"351 : 0.17109857\n",
"352 : 0.10401657\n",
"353 : 0.18177532\n",
"354 : 0.16407819\n",
"355 : 0.1328129\n",
"356 : 0.12777561\n",
"357 : 0.16972435\n",
"358 : 0.121474735\n",
"359 : 0.1802315\n",
"360 : 0.17138006\n",
"361 : 0.14497444\n",
"362 : 0.1328586\n",
"363 : 0.12166094\n",
"364 : 0.14256704\n",
"365 : 0.13014755\n",
"366 : 0.15340823\n",
"367 : 0.12284045\n",
"368 : 0.17081615\n",
"369 : 0.09157196\n",
"370 : 0.14261661\n",
"371 : 0.15789641\n",
"372 : 0.14877114\n",
"373 : 0.19000767\n",
"374 : 0.18016705\n",
"375 : 0.13460127\n",
"376 : 0.1489218\n",
"377 : 0.15343726\n",
"378 : 0.12069669\n",
"379 : 0.16354252\n",
"380 : 0.1271183\n",
"381 : 0.15460804\n",
"382 : 0.1431959\n",
"383 : 0.17617328\n",
"384 : 0.17155463\n",
"385 : 0.16052125\n",
"386 : 0.20818532\n",
"387 : 0.22987455\n",
"388 : 0.12673472\n",
"389 : 0.116720706\n",
"390 : 0.12031011\n",
"391 : 0.10456814\n",
"392 : 0.1515617\n",
"393 : 0.110967904\n",
"394 : 0.14067063\n",
"395 : 0.10152199\n",
"396 : 0.14363636\n",
"397 : 0.109732\n",
"398 : 0.15370482\n",
"399 : 0.23777321\n",
"400 : 0.17698498\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 5\n",
"401 : 0.14186129\n",
"402 : 0.16293369\n",
"403 : 0.12690492\n",
"404 : 0.12299286\n",
"405 : 0.14145467\n",
"406 : 0.16558231\n",
"407 : 0.18628284\n",
"408 : 0.13790326\n",
"409 : 0.15945606\n",
"410 : 0.15455613\n",
"411 : 0.18341199\n",
"412 : 0.14914699\n",
"413 : 0.15578014\n",
"414 : 0.12761003\n",
"415 : 0.14642216\n",
"416 : 0.19977523\n",
"417 : 0.11384835\n",
"418 : 0.11266821\n",
"419 : 0.16545223\n",
"420 : 0.18985954\n",
"421 : 0.14535917\n",
"422 : 0.11091176\n",
"423 : 0.12428767\n",
"424 : 0.16970477\n",
"425 : 0.106244504\n",
"426 : 0.16632122\n",
"427 : 0.12626189\n",
"428 : 0.12667525\n",
"429 : 0.16489224\n",
"430 : 0.17477669\n",
"431 : 0.12704808\n",
"432 : 0.14285174\n",
"433 : 0.1492689\n",
"434 : 0.12137743\n",
"435 : 0.1854474\n",
"436 : 0.14642225\n",
"437 : 0.15149617\n",
"438 : 0.16470455\n",
"439 : 0.13866177\n",
"440 : 0.15635641\n",
"441 : 0.15837139\n",
"442 : 0.13060589\n",
"443 : 0.15263985\n",
"444 : 0.12847354\n",
"445 : 0.14056405\n",
"446 : 0.1599973\n",
"447 : 0.1268467\n",
"448 : 0.15149932\n",
"449 : 0.15069152\n",
"450 : 0.095344715\n",
"451 : 0.12692823\n",
"452 : 0.16134982\n",
"453 : 0.10846744\n",
"454 : 0.110591486\n",
"455 : 0.11525818\n",
"456 : 0.12970857\n",
"457 : 0.09025608\n",
"458 : 0.13471705\n",
"459 : 0.1125316\n",
"460 : 0.17845346\n",
"461 : 0.13616502\n",
"462 : 0.13095608\n",
"463 : 0.11662094\n",
"464 : 0.13172145\n",
"465 : 0.113783605\n",
"466 : 0.1468534\n",
"467 : 0.09930524\n",
"468 : 0.123042375\n",
"469 : 0.11898151\n",
"470 : 0.09886933\n",
"471 : 0.10984399\n",
"472 : 0.123674296\n",
"473 : 0.15884851\n",
"474 : 0.09473106\n",
"475 : 0.13640635\n",
"476 : 0.14475779\n",
"477 : 0.09994569\n",
"478 : 0.13259643\n",
"479 : 0.15337387\n",
"480 : 0.10862705\n",
"481 : 0.09243172\n",
"482 : 0.12102955\n",
"483 : 0.12756248\n",
"484 : 0.06705673\n",
"485 : 0.112004735\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"486 : 0.081523485\n",
"487 : 0.13786794\n",
"488 : 0.09381305\n",
"489 : 0.11110353\n",
"490 : 0.098182514\n",
"491 : 0.09387477\n",
"492 : 0.09438832\n",
"493 : 0.095476516\n",
"494 : 0.06817464\n",
"495 : 0.09543871\n",
"496 : 0.07350923\n",
"497 : 0.11202486\n",
"498 : 0.059394952\n",
"499 : 0.085059434\n",
"500 : 0.052376606\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 6\n",
"501 : 0.07366293\n",
"502 : 0.0977461\n",
"503 : 0.07575514\n",
"504 : 0.08485856\n",
"505 : 0.09932089\n",
"506 : 0.09266989\n",
"507 : 0.06556279\n",
"508 : 0.08859516\n",
"509 : 0.06656102\n",
"510 : 0.074160814\n",
"511 : 0.06465433\n",
"512 : 0.0768479\n",
"513 : 0.061056625\n",
"514 : 0.08437074\n",
"515 : 0.0702509\n",
"516 : 0.12187559\n",
"517 : 0.083756074\n",
"518 : 0.10581833\n",
"519 : 0.08884242\n",
"520 : 0.06761867\n",
"521 : 0.082420155\n",
"522 : 0.08711414\n",
"523 : 0.095563926\n",
"524 : 0.06527415\n",
"525 : 0.05995813\n",
"526 : 0.06321092\n",
"527 : 0.06498899\n",
"528 : 0.08993475\n",
"529 : 0.056264304\n",
"530 : 0.07335807\n",
"531 : 0.06862191\n",
"532 : 0.06770751\n",
"533 : 0.079412185\n",
"534 : 0.06537493\n",
"535 : 0.06720868\n",
"536 : 0.05963801\n",
"537 : 0.09473534\n",
"538 : 0.09055745\n",
"539 : 0.04787309\n",
"540 : 0.05338252\n",
"541 : 0.06361542\n",
"542 : 0.070608474\n",
"543 : 0.045933895\n",
"544 : 0.11329553\n",
"545 : 0.064792864\n",
"546 : 0.069107845\n",
"547 : 0.0710126\n",
"548 : 0.08020279\n",
"549 : 0.078129314\n",
"550 : 0.06880879\n",
"551 : 0.06652519\n",
"552 : 0.048686218\n",
"553 : 0.042917076\n",
"554 : 0.074027404\n",
"555 : 0.04995659\n",
"556 : 0.099824965\n",
"557 : 0.076344736\n",
"558 : 0.06369189\n",
"559 : 0.050895676\n",
"560 : 0.07711893\n",
"561 : 0.06045536\n",
"562 : 0.07851929\n",
"563 : 0.08319768\n",
"564 : 0.05253439\n",
"565 : 0.05752097\n",
"566 : 0.08168973\n",
"567 : 0.06884125\n",
"568 : 0.055582065\n",
"569 : 0.08008318\n",
"570 : 0.07677271\n",
"571 : 0.05553363\n",
"572 : 0.059638027\n",
"573 : 0.07149445\n",
"574 : 0.064126\n",
"575 : 0.060046457\n",
"576 : 0.047065362\n",
"577 : 0.06605486\n",
"578 : 0.045319017\n",
"579 : 0.07018263\n",
"580 : 0.041880213\n",
"581 : 0.06465956\n",
"582 : 0.06628401\n",
"583 : 0.059362903\n",
"584 : 0.057434283\n",
"585 : 0.041649923\n",
"586 : 0.049649864\n",
"587 : 0.05602722\n",
"588 : 0.051012743\n",
"589 : 0.063672274\n",
"590 : 0.060314998\n",
"591 : 0.051572323\n",
"592 : 0.05738103\n",
"593 : 0.049807936\n",
"594 : 0.05839544\n",
"595 : 0.060877696\n",
"596 : 0.04523378\n",
"597 : 0.07385907\n",
"598 : 0.05840019\n",
"599 : 0.044012066\n",
"600 : 0.048904534\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 7\n",
"601 : 0.04777501\n",
"602 : 0.045295812\n",
"603 : 0.054420196\n",
"604 : 0.036967106\n",
"605 : 0.051850665\n",
"606 : 0.050646733\n",
"607 : 0.042383656\n",
"608 : 0.042954408\n",
"609 : 0.04449568\n",
"610 : 0.043360166\n",
"611 : 0.052302714\n",
"612 : 0.051783744\n",
"613 : 0.043744653\n",
"614 : 0.042835366\n",
"615 : 0.042453796\n",
"616 : 0.04722375\n",
"617 : 0.06010188\n",
"618 : 0.05015109\n",
"619 : 0.07052367\n",
"620 : 0.06358686\n",
"621 : 0.054715637\n",
"622 : 0.043309633\n",
"623 : 0.038239118\n",
"624 : 0.04742624\n",
"625 : 0.031015906\n",
"626 : 0.039885774\n",
"627 : 0.04358694\n",
"628 : 0.037214097\n",
"629 : 0.041405972\n",
"630 : 0.043837957\n",
"631 : 0.037334636\n",
"632 : 0.037288006\n",
"633 : 0.038870875\n",
"634 : 0.04596491\n",
"635 : 0.03999333\n",
"636 : 0.031630196\n",
"637 : 0.032454945\n",
"638 : 0.042408675\n",
"639 : 0.037228826\n",
"640 : 0.03534526\n",
"641 : 0.03495513\n",
"642 : 0.038005464\n",
"643 : 0.05133914\n",
"644 : 0.031021714\n",
"645 : 0.06480448\n",
"646 : 0.055641443\n",
"647 : 0.033429805\n",
"648 : 0.07204032\n",
"649 : 0.035686243\n",
"650 : 0.041231398\n",
"651 : 0.04310468\n",
"652 : 0.05581066\n",
"653 : 0.028777342\n",
"654 : 0.027452188\n",
"655 : 0.041834883\n",
"656 : 0.04338586\n",
"657 : 0.047400884\n",
"658 : 0.034692727\n",
"659 : 0.029945832\n",
"660 : 0.04812961\n",
"661 : 0.031002436\n",
"662 : 0.03977715\n",
"663 : 0.03498343\n",
"664 : 0.039574455\n",
"665 : 0.022422504\n",
"666 : 0.041366577\n",
"667 : 0.018463558\n",
"668 : 0.043331485\n",
"669 : 0.034572016\n",
"670 : 0.025236774\n",
"671 : 0.042498186\n",
"672 : 0.026631989\n",
"673 : 0.033520643\n",
"674 : 0.03707531\n",
"675 : 0.022538457\n",
"676 : 0.027253294\n",
"677 : 0.037255332\n",
"678 : 0.026199946\n",
"679 : 0.032597672\n",
"680 : 0.025695058\n",
"681 : 0.02061349\n",
"682 : 0.03616573\n",
"683 : 0.025639756\n",
"684 : 0.04018739\n",
"685 : 0.034272414\n",
"686 : 0.01547039\n",
"687 : 0.03216763\n",
"688 : 0.036220647\n",
"689 : 0.024898035\n",
"690 : 0.03058352\n",
"691 : 0.025642531\n",
"692 : 0.024973879\n",
"693 : 0.014782318\n",
"694 : 0.02827166\n",
"695 : 0.016179822\n",
"696 : 0.02347338\n",
"697 : 0.023594905\n",
"698 : 0.021632273\n",
"699 : 0.022306101\n",
"700 : 0.016289122\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 8\n",
"701 : 0.020902649\n",
"702 : 0.025259122\n",
"703 : 0.025379717\n",
"704 : 0.022688245\n",
"705 : 0.016392633\n",
"706 : 0.01776363\n",
"707 : 0.017607488\n",
"708 : 0.033347383\n",
"709 : 0.021862421\n",
"710 : 0.015916483\n",
"711 : 0.030695878\n",
"712 : 0.031233415\n",
"713 : 0.016633924\n",
"714 : 0.033028144\n",
"715 : 0.021118999\n",
"716 : 0.023250546\n",
"717 : 0.02820932\n",
"718 : 0.017477375\n",
"719 : 0.021663867\n",
"720 : 0.03238818\n",
"721 : 0.018695714\n",
"722 : 0.025090491\n",
"723 : 0.022978777\n",
"724 : 0.025241503\n",
"725 : 0.023944862\n",
"726 : 0.019659953\n",
"727 : 0.02097785\n",
"728 : 0.01966372\n",
"729 : 0.023441292\n",
"730 : 0.020550003\n",
"731 : 0.018677153\n",
"732 : 0.020555295\n",
"733 : 0.0153756095\n",
"734 : 0.025121372\n",
"735 : 0.01652127\n",
"736 : 0.0220072\n",
"737 : 0.00999939\n",
"738 : 0.025371652\n",
"739 : 0.01686145\n",
"740 : 0.024796648\n",
"741 : 0.014960512\n",
"742 : 0.016903607\n",
"743 : 0.021749733\n",
"744 : 0.019653985\n",
"745 : 0.016175628\n",
"746 : 0.01462836\n",
"747 : 0.02161517\n",
"748 : 0.020391822\n",
"749 : 0.037578925\n",
"750 : 0.022586687\n",
"751 : 0.0155431675\n",
"752 : 0.016204944\n",
"753 : 0.016131276\n",
"754 : 0.019290337\n",
"755 : 0.025681803\n",
"756 : 0.023121621\n",
"757 : 0.028586132\n",
"758 : 0.03171596\n",
"759 : 0.019623036\n",
"760 : 0.031834494\n",
"761 : 0.023419132\n",
"762 : 0.015691467\n",
"763 : 0.033500258\n",
"764 : 0.022730062\n",
"765 : 0.019922173\n",
"766 : 0.032994896\n",
"767 : 0.014091652\n",
"768 : 0.027716422\n",
"769 : 0.026400797\n",
"770 : 0.012892388\n",
"771 : 0.032470938\n",
"772 : 0.01726125\n",
"773 : 0.020152602\n",
"774 : 0.022208914\n",
"775 : 0.014932071\n",
"776 : 0.012937251\n",
"777 : 0.015766947\n",
"778 : 0.0132568255\n",
"779 : 0.0155861\n",
"780 : 0.017096061\n",
"781 : 0.013463872\n",
"782 : 0.018087575\n",
"783 : 0.024214478\n",
"784 : 0.010821706\n",
"785 : 0.010175236\n",
"786 : 0.026655247\n",
"787 : 0.011721212\n",
"788 : 0.02163083\n",
"789 : 0.014458537\n",
"790 : 0.011331205\n",
"791 : 0.015352869\n",
"792 : 0.014839578\n",
"793 : 0.012153065\n",
"794 : 0.020075902\n",
"795 : 0.021816878\n",
"796 : 0.022544676\n",
"797 : 0.018230867\n",
"798 : 0.016187347\n",
"799 : 0.013606094\n",
"800 : 0.024549238\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 9\n",
"801 : 0.012986608\n",
"802 : 0.016526671\n",
"803 : 0.033521313\n",
"804 : 0.014552559\n",
"805 : 0.017782925\n",
"806 : 0.014623911\n",
"807 : 0.01605843\n",
"808 : 0.0182314\n",
"809 : 0.023435459\n",
"810 : 0.013944712\n",
"811 : 0.01235519\n",
"812 : 0.010782383\n",
"813 : 0.019945942\n",
"814 : 0.016715635\n",
"815 : 0.0199111\n",
"816 : 0.015859725\n",
"817 : 0.020479776\n",
"818 : 0.015555834\n",
"819 : 0.015800904\n",
"820 : 0.012691012\n",
"821 : 0.016580176\n",
"822 : 0.019633785\n",
"823 : 0.013375631\n",
"824 : 0.016445559\n",
"825 : 0.012265226\n",
"826 : 0.015153025\n",
"827 : 0.014212745\n",
"828 : 0.016773904\n",
"829 : 0.013638898\n",
"830 : 0.011859362\n",
"831 : 0.0123380935\n",
"832 : 0.022636991\n",
"833 : 0.0189455\n",
"834 : 0.013149257\n",
"835 : 0.017037136\n",
"836 : 0.016768163\n",
"837 : 0.01088644\n",
"838 : 0.01632203\n",
"839 : 0.018193217\n",
"840 : 0.01501472\n",
"841 : 0.019435056\n",
"842 : 0.015862606\n",
"843 : 0.016691906\n",
"844 : 0.01857283\n",
"845 : 0.016344197\n",
"846 : 0.016242089\n",
"847 : 0.011747944\n",
"848 : 0.015716914\n",
"849 : 0.011858826\n",
"850 : 0.00779434\n",
"851 : 0.01895096\n",
"852 : 0.013304486\n",
"853 : 0.016563827\n",
"854 : 0.018470047\n",
"855 : 0.017850008\n",
"856 : 0.00965603\n",
"857 : 0.014101025\n",
"858 : 0.01428686\n",
"859 : 0.019248547\n",
"860 : 0.024916233\n",
"861 : 0.010775974\n",
"862 : 0.014567508\n",
"863 : 0.014941969\n",
"864 : 0.019449465\n",
"865 : 0.009467914\n",
"866 : 0.016918374\n",
"867 : 0.014448612\n",
"868 : 0.0090315\n",
"869 : 0.019191442\n",
"870 : 0.01305908\n",
"871 : 0.012039092\n",
"872 : 0.009064677\n",
"873 : 0.014514124\n",
"874 : 0.013841455\n",
"875 : 0.017515764\n",
"876 : 0.012291939\n",
"877 : 0.014437266\n",
"878 : 0.01352719\n",
"879 : 0.010692399\n",
"880 : 0.009278242\n",
"881 : 0.009847321\n",
"882 : 0.013665096\n",
"883 : 0.012355395\n",
"884 : 0.01276628\n",
"885 : 0.016411718\n",
"886 : 0.014977279\n",
"887 : 0.013307059\n",
"888 : 0.010135107\n",
"889 : 0.01614172\n",
"890 : 0.0103136115\n",
"891 : 0.010267015\n",
"892 : 0.013264958\n",
"893 : 0.025968643\n",
"894 : 0.011141503\n",
"895 : 0.012102563\n",
"896 : 0.018123172\n",
"897 : 0.023255836\n",
"898 : 0.009807067\n",
"899 : 0.01969999\n",
"900 : 0.013991203\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 10\n",
"901 : 0.011699881\n",
"902 : 0.010809604\n",
"903 : 0.012074033\n",
"904 : 0.016375568\n",
"905 : 0.013168372\n",
"906 : 0.016547907\n",
"907 : 0.016467936\n",
"908 : 0.010323469\n",
"909 : 0.014534895\n",
"910 : 0.014145259\n",
"911 : 0.012574308\n",
"912 : 0.019426407\n",
"913 : 0.01383512\n",
"914 : 0.014578431\n",
"915 : 0.018004967\n",
"916 : 0.015512292\n",
"917 : 0.022178264\n",
"918 : 0.016814796\n",
"919 : 0.013876684\n",
"920 : 0.01231194\n",
"921 : 0.018006697\n",
"922 : 0.010835016\n",
"923 : 0.01271348\n",
"924 : 0.015747424\n",
"925 : 0.01589224\n",
"926 : 0.014418066\n",
"927 : 0.015718684\n",
"928 : 0.01758476\n",
"929 : 0.018121056\n",
"930 : 0.015334339\n",
"931 : 0.014507584\n",
"932 : 0.016220575\n",
"933 : 0.014739513\n",
"934 : 0.015753098\n",
"935 : 0.015109981\n",
"936 : 0.016242303\n",
"937 : 0.013783309\n",
"938 : 0.018552832\n",
"939 : 0.015021067\n",
"940 : 0.01588847\n",
"941 : 0.007901014\n",
"942 : 0.010706764\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"943 : 0.011875818\n",
"944 : 0.009333177\n",
"945 : 0.02218783\n",
"946 : 0.012051043\n",
"947 : 0.014832691\n",
"948 : 0.014634831\n",
"949 : 0.0131747965\n",
"950 : 0.01689469\n",
"951 : 0.024831302\n",
"952 : 0.01531312\n",
"953 : 0.013977639\n",
"954 : 0.012559077\n",
"955 : 0.011145801\n",
"956 : 0.016348049\n",
"957 : 0.014868401\n",
"958 : 0.0115115205\n",
"959 : 0.0137658\n",
"960 : 0.0151611185\n",
"961 : 0.011013189\n",
"962 : 0.008830121\n",
"963 : 0.014184743\n",
"964 : 0.012976608\n",
"965 : 0.00957146\n",
"966 : 0.011462526\n",
"967 : 0.01051829\n",
"968 : 0.012518496\n",
"969 : 0.01954012\n",
"970 : 0.012688293\n",
"971 : 0.015191065\n",
"972 : 0.011763955\n",
"973 : 0.0078075076\n",
"974 : 0.014996645\n",
"975 : 0.014906674\n",
"976 : 0.010484314\n",
"977 : 0.017806204\n",
"978 : 0.01658817\n",
"979 : 0.008808107\n",
"980 : 0.012971384\n",
"981 : 0.010011482\n",
"982 : 0.018633323\n",
"983 : 0.011638668\n",
"984 : 0.01172273\n",
"985 : 0.016627317\n",
"986 : 0.011784148\n",
"987 : 0.015244544\n",
"988 : 0.010136694\n",
"989 : 0.012415483\n",
"990 : 0.010214868\n",
"991 : 0.009691322\n",
"992 : 0.010831093\n",
"993 : 0.013265954\n",
"994 : 0.018168332\n",
"995 : 0.012348447\n",
"996 : 0.014728769\n",
"997 : 0.011506953\n",
"998 : 0.011384925\n",
"999 : 0.012256512\n",
"1000 : 0.011961029\n",
"\n",
"Saving weights\n",
"\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"\n",
"with tf.device('/gpu:0'):\n",
"    model = JANETModel(NUM_UNITS, num_outputs=1, num_timesteps=TIME_STEPS, output_activation='linear')\n",
"\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)\n",
"    global_step = tf.train.get_or_create_global_step()\n",
"\n",
"    # try using different optimizers and different optimizer configs\n",
"    model.compile(loss='mse', optimizer=optimizer)\n",
"\n",
"    best_loss = 100.\n",
"    generator = batch_generator()\n",
"\n",
"    loss_history = []\n",
"    update_counter = 1\n",
"\n",
"    for epoch in range(NUM_EPOCHS):\n",
"        print(\"Epoch # : \", epoch + 1)\n",
"\n",
"        for step in range(STEPS_PER_EPOCH):\n",
"            # get a batch of training data\n",
"            inputs, targets = next(generator)\n",
"            \n",
"            # get gradients and loss at this iteration\n",
"            gradients, loss_val = grad(model, inputs, targets)\n",
"            \n",
"            # apply gradients and increment the global step\n",
"            optimizer.apply_gradients(gradients, global_step=global_step)\n",
"\n",
"            loss_history.append(loss_val.numpy())\n",
"            print(update_counter, \":\", loss_history[-1])\n",
"\n",
"            update_counter += 1\n",
"        print()\n",
"        \n",
"        # remove the old checkpoint that we no longer need\n",
"        if os.path.exists(CHECKPOINTS_DIR):\n",
"            shutil.rmtree(CHECKPOINTS_DIR)\n",
"        \n",
"        # save the checkpoint weights\n",
"        checkpoint = tfe.Checkpoint(model=model).save(CHECKPOINT_PATH)\n",
"        \n",
"        # Optional : Save the weight matrices in Keras format as well\n",
"        #model.save_weights('addition_model.h5', overwrite=True)\n",
"        \n",
"        print(\"Saving weights\")\n",
"        print()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x20e90f7b9e8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# remove noise from initial epochs\n",
"loss_history_plot = list(filter(lambda x: x < 0.25, loss_history))\n",
"\n",
"plt.figure(figsize=(12, 5))\n",
"plt.plot(loss_history_plot)\n",
"plt.show()"
]
},
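{
"cell_type": "markdown",
"metadata": {},
"source": [
"The onset of learning is read off the plot by eye in the observations below. A small helper like the following can make that reading reproducible. It is a hypothetical sketch, not part of the original notebook: it assumes `loss_history` is the list of per-iteration losses built in the training loop, and the window size and threshold are illustrative choices rather than values from the paper.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def first_iteration_below(losses, threshold, window=20):\n",
"    # Smooth the noisy per-batch losses with a moving average over `window` iterations\n",
"    losses = np.asarray(losses, dtype=np.float64)\n",
"    if losses.size < window:\n",
"        return None\n",
"    smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')\n",
"    # smoothed[i] is the mean of losses[i : i + window]\n",
"    below = np.nonzero(smoothed < threshold)[0]\n",
"    if below.size == 0:\n",
"        return None\n",
"    # 1-based iteration number at the end of the first window whose mean is below threshold\n",
"    return int(below[0]) + window\n",
"\n",
"# e.g. first_iteration_below(loss_history, threshold=0.05)\n",
"```\n",
"\n",
"Smoothing before thresholding keeps a single lucky batch from being reported as the point where the model starts to learn."
]
},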
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Observations\n",
"The model does train on 100 timesteps fairly quickly, but it doesn't match the learning speed reported in the paper. In the paper, the loss starts to drop close to the 300th iteration and falls to the low 0.0x range by the 425th iteration.\n",
"\n",
"Here the loss starts to fall at around the 400th iteration and has fully dropped by the 600th iteration, much slower than in the paper. Also, I had to clip off the first few iterations, which have extremely high losses, because the plot is unreadable if they are kept."
]
},
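{
"cell_type": "markdown",
"metadata": {},
"source": [
"A useful reference point for these curves: a model that ignores its inputs and always predicts the mean target scores an MSE equal to the variance of the target. Assuming the `batch_generator` defined earlier draws values uniformly from [0, 1] and the target is the sum of two of them (the usual addition-problem setup; this is an assumption, not checked here), that variance is 2 × 1/12 = 1/6 ≈ 0.167. The custom-model training log further down hovers around that level before its loss starts to fall. A quick Monte Carlo estimate confirms the number:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Hypothetical stand-in for the addition task: each target is the sum of two U[0, 1] values\n",
"num_samples = 200000\n",
"targets = rng.uniform(0.0, 1.0, size=(num_samples, 2)).sum(axis=1)\n",
"\n",
"# Under MSE, the best constant prediction is the mean target (1.0 in expectation)\n",
"constant_prediction = targets.mean()\n",
"baseline_mse = np.mean((targets - constant_prediction) ** 2)\n",
"\n",
"print('empirical baseline MSE :', baseline_mse)\n",
"print('analytic value (1/6)   :', 1.0 / 6.0)\n",
"```\n",
"\n",
"Under the same assumption, a loss stuck near 0.167 means the network has not yet learned which inputs to add; the drop below that level marks the point where it starts solving the task."
]
},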
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading Checkpoints\n",
"\n",
"Tensorflow Eager makes it easy to load checkpoints for models that are fully contained in Keras layers. However, there is one small idiosyncrasy to deal with when loading models in TF 1.7.\n",
"\n",
"After being built, a model must be called at least once before the checkpoint can restore its weights. If this is not done, restoring fails with obscure errors."
]
},
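{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why the extra call helps, as far as we can tell (this is our reading of the error, not something confirmed against the TF 1.7 source): Keras layers create their weights lazily, on the first call, once the input shape is known. Until then the model has no variables for the checkpoint to write into. The toy `LazyDense` class below is hypothetical and uses plain NumPy instead of TensorFlow; it only illustrates the build-on-first-call pattern.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"class LazyDense:\n",
"    # Hypothetical toy layer: weights are created on the first call,\n",
"    # mirroring how Keras layers build themselves from the input shape\n",
"    def __init__(self, units):\n",
"        self.units = units\n",
"        self.kernel = None\n",
"\n",
"    def __call__(self, x):\n",
"        if self.kernel is None:\n",
"            # 'build' step: the input dimension is only known now\n",
"            self.kernel = np.zeros((x.shape[-1], self.units))\n",
"        return x @ self.kernel\n",
"\n",
"    def restore(self, saved_kernel):\n",
"        # Restoring needs an existing weight with a matching shape\n",
"        if self.kernel is None:\n",
"            raise RuntimeError('layer has no weights yet; call it once before restoring')\n",
"        if self.kernel.shape != saved_kernel.shape:\n",
"            raise ValueError('shape mismatch')\n",
"        self.kernel[...] = saved_kernel\n",
"\n",
"saved = np.ones((2, 3))\n",
"layer = LazyDense(3)\n",
"layer(np.zeros((1, 2)))  # dummy call, like model(zeros) in the next cell\n",
"layer.restore(saved)\n",
"print(layer(np.ones((1, 2))))  # -> [[2. 2. 2.]]\n",
"```"
]
},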
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Checkpoint path : checkpoints_addition/addition_janet.ckpt-1\n",
"Final average predicted error (should be less than 0.03) : 0.013117644\n"
]
}
],
"source": [
"if os.path.exists(CHECKPOINTS_DIR):\n",
"    ckpt_path = tf.train.latest_checkpoint(CHECKPOINTS_DIR)\n",
"    print(\"Checkpoint path : \", ckpt_path)\n",
"\n",
"    model = JANETModel(NUM_UNITS, num_outputs=1, num_timesteps=TIME_STEPS, output_activation='linear')\n",
"\n",
"    model.compile(tf.train.AdamOptimizer(), loss='mse')\n",
"    \n",
"    # this is where you need to call the model at least once,\n",
"    # so that all of its variables can be properly restored\n",
"    zeros = tf.zeros((1, TIME_STEPS, 2))\n",
"    model(zeros)\n",
"    \n",
"    # restore the weights\n",
"    tfe.Checkpoint(model=model).restore(ckpt_path)\n",
"    \n",
"    # evaluate on 100 fresh batches to confirm the weights were loaded correctly\n",
"    generator = batch_generator()\n",
"\n",
"    losses = []\n",
"    for i in range(100):\n",
"        inputs, outputs = next(generator)\n",
"\n",
"        preds = model(inputs)\n",
"        loss_val = loss(outputs, preds)\n",
"\n",
"        losses.append(loss_val.numpy())\n",
"\n",
"    print(\"Final average predicted error (should be less than 0.03) : \", np.mean(losses))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stepping off the well trodden path\n",
"When building Keras models that **only** use Keras layers, there is no issue when loading or saving weights.\n",
"\n",
"However, not everything can be expressed as layers, and some models may require working directly with TF Eager variables. In such cases, a different approach is required to save and restore these models.\n",
"\n",
"Let's look at the same model without using Keras layers for the recurrent part. To keep it simple, we will use Eager variables to define the RNN cell, and a Keras layer to define the classification / regression layer."
]
},
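{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell builds the recurrent bias by concatenating a chrono-initialized forget-gate bias with zeros for the memory gate. For reference, chrono initialization (Tallec & Ollivier, linked at the top of this notebook) samples the forget bias as `b_f ~ log(U[1, T_max - 1])`. The NumPy sketch below reproduces that layout under this assumption; it is not the `ChronoInitializer` defined earlier in the notebook, whose details may differ, and `chrono_janet_bias` is a hypothetical helper name.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def chrono_janet_bias(units, max_timesteps, rng=None):\n",
"    # Layout matches recurrent_bias in EagerJANETModel: [forget gate | memory gate]\n",
"    rng = np.random.default_rng() if rng is None else rng\n",
"    # Chrono initialization of the forget gate: b_f ~ log(U[1, T_max - 1])\n",
"    forget_bias = np.log(rng.uniform(1.0, max_timesteps - 1, size=units))\n",
"    # The memory (candidate) gate bias starts at zero\n",
"    memory_bias = np.zeros(units)\n",
"    return np.concatenate([forget_bias, memory_bias]).astype(np.float32)\n",
"\n",
"bias = chrono_janet_bias(units=4, max_timesteps=100, rng=np.random.default_rng(0))\n",
"print(bias.shape)  # -> (8,)\n",
"```\n",
"\n",
"Since `sigmoid(log(u)) = u / (1 + u)`, each unit starts with a forget gate whose characteristic memory time `1 / (1 - f) = 1 + u` lies between 2 and `T_max` steps, which is the point of the chrono scheme."
]
},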
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"class EagerJANETModel(tf.keras.Model):\n",
"    \n",
"    def __init__(self, input_dim, units, num_outputs, num_timesteps, output_activation='sigmoid', **kwargs):\n",
"        super(EagerJANETModel, self).__init__(**kwargs)\n",
"        \n",
"        self.input_dim = input_dim\n",
"        self.units = units\n",
"        self.classes = num_outputs\n",
"        self.num_timesteps = num_timesteps\n",
"\n",
"        # Initialize the forget gate with ChronoInitializer\n",
"        # The memory gate is initialized with zeros\n",
"        def bias_initializer(_, *args, **kwargs):\n",
"            forget_gate = ChronoInitializer(self.num_timesteps)((self.units,), *args, **kwargs)\n",
"            \n",
"            return tf.keras.backend.concatenate([\n",
"                forget_gate,\n",
"                tf.keras.initializers.Zeros()((self.units,), *args, **kwargs),\n",
"            ])\n",
"        \n",
"        # Input and recurrent kernels; each produces the pre-activations of both gates\n",
"        self.kernel = tf.get_variable('kernel', shape=[input_dim, units * 2], dtype=tf.float32,\n",
"                                      initializer=tf.keras.initializers.glorot_uniform())\n",
"        \n",
"        self.recurrent_kernel = tf.get_variable('recurrent_kernel', shape=[units, units * 2], dtype=tf.float32,\n",
"                                                initializer=tf.keras.initializers.glorot_uniform())\n",
"        \n",
"        self.recurrent_bias = tf.get_variable('recurrent_bias', shape=[units * 2], dtype=tf.float32,\n",
"                                              initializer=bias_initializer)\n",
"        \n",
"        # Initialize the final layer (for classification or regression depending on the output activation)\n",
"        self.output_dense = tf.keras.layers.Dense(num_outputs, activation=output_activation)\n",
"        \n",
"        # we need to create a dictionary of all of the weights which are not in Keras layers\n",
"        self.additional_weights = OrderedDict()\n",
"        self.additional_weights[self.kernel.name] = self.kernel\n",
"        self.additional_weights[self.recurrent_kernel.name] = self.recurrent_kernel\n",
"        self.additional_weights[self.recurrent_bias.name] = self.recurrent_bias\n",
"    \n",
"    def __call__(self, *args, **kwargs):\n",
"        \"\"\"\n",
"        This override works around an issue with how a checkpoint loads the weights of a model.\n",
"        \n",
"        While we could use `model.call(inputs)` directly, the more pythonic way of doing this is\n",
"        by using `model(inputs)`.\n",
"        \n",
"        However, in Eager mode, the ordinary function call does *not* forward to the\n",
"        model's `call` method. This override bypasses that issue.\n",
"        \"\"\"\n",
"        if not tf.executing_eagerly():\n",
"            super(EagerJANETModel, self).__call__(*args, **kwargs)\n",
"        return self.call(*args, **kwargs)\n",
"\n",
"\n",
"    def call(self, inputs, training=None, mask=None):\n",
"        # Initialize the hidden and memory states\n",
"        outputs = []\n",
"        states = []\n",
"        h_state = tf.zeros((inputs.shape[0], self.units))\n",
"        c_state = tf.zeros((inputs.shape[0], self.units))\n",
"        \n",
"        # Input is in the shape [None, timesteps, input_dim]\n",
"        for t in range(inputs.shape[1]):\n",
"            ip = inputs[:, t, :]  # access the t'th timestep\n",
"            \n",
"            # Perform the forward pass of the model\n",
"            z = tf.matmul(ip, self.kernel)\n",
"            z += tf.matmul(h_state, self.recurrent_kernel) + self.recurrent_bias\n",
"            \n",
"            # Split the output into the forget and memory outputs\n",
"            z0 = z[:, :self.units]\n",
"            z1 = z[:, self.units: 2 * self.units]\n",
"\n",
"            # gate updates\n",
"            f = tf.keras.activations.sigmoid(z0)\n",
"            c = f * c_state + (1. - f) * tf.nn.tanh(z1)\n",
"\n",
"            # state updates\n",
"            h = c\n",
"            \n",
"            # update our previous state\n",
"            h_state = h\n",
"            c_state = c\n",
"            \n",
"            # preserve the history of our states\n",
"            outputs.append(h)\n",
"            states.append([h])  # here we maintain just 1 state variable, can be more\n",
"        \n",
"        # store the per-timestep outputs and states as attributes so we can access them later\n",
"        self.cell_outputs = tf.stack(outputs, axis=1)\n",
"        self.cell_states = states\n",
"        \n",
"        # perform forward pass of the classifier / regressor\n",
"        preds = self.output_dense(outputs[-1])\n",
"\n",
"        return preds"
]
},
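{
"cell_type": "markdown",
"metadata": {},
"source": [
"To sanity-check the recurrence in `call` outside TensorFlow, here is a NumPy sketch of a single time step with the same weight layout: `kernel` is [input_dim, 2 * units], `recurrent_kernel` is [units, 2 * units], and the forget-gate pre-activation comes first. It mirrors this notebook's cell, not necessarily every detail of the paper's formulation; `janet_step` is a hypothetical name.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"def janet_step(x, h_prev, c_prev, kernel, recurrent_kernel, bias):\n",
"    # Mirrors one iteration of the loop in EagerJANETModel.call\n",
"    units = h_prev.shape[-1]\n",
"    z = x @ kernel + h_prev @ recurrent_kernel + bias\n",
"    z0, z1 = z[:, :units], z[:, units:2 * units]\n",
"    f = sigmoid(z0)                           # forget gate\n",
"    c = f * c_prev + (1.0 - f) * np.tanh(z1)  # gated memory update\n",
"    h = c                                     # JANET exposes the cell state as its output\n",
"    return h, c\n",
"\n",
"# With all-zero parameters, f = 0.5 and tanh(0) = 0, so the state halves each step\n",
"batch, input_dim, units = 1, 2, 3\n",
"x = np.ones((batch, input_dim))\n",
"h0 = np.ones((batch, units))\n",
"c0 = np.ones((batch, units))\n",
"kernel = np.zeros((input_dim, 2 * units))\n",
"recurrent_kernel = np.zeros((units, 2 * units))\n",
"bias = np.zeros(2 * units)\n",
"h1, c1 = janet_step(x, h0, c0, kernel, recurrent_kernel, bias)\n",
"print(c1)  # -> [[0.5 0.5 0.5]]\n",
"```\n",
"\n",
"Setting a large positive forget bias instead pushes `f` toward 1 and the state is carried over almost unchanged, which is why the forget-gate initialization matters so much for long sequences."
]
},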
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradients of a custom model\n",
"\n",
"Computing the gradients of such a model is a little more involved than simply using `model.trainable_variables` as before. Since tf.keras Models do not track weights that are added separately (only the weights of Keras layers are managed by Keras), we need to write a custom grad function."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"def grad(model, X, y):\n",
"    with tfe.GradientTape() as tape:\n",
"        preds = model(X)\n",
"        loss_val = loss(y, preds)\n",
"    \n",
"    # this is the crucial step : use the dictionary of weights that we manage\n",
"    # manually to get all of the weights that are not managed by keras and add\n",
"    # them to the list of weights for which we need gradients\n",
"    variables = model.trainable_variables + list(model.additional_weights.values())\n",
"    grads = tape.gradient(loss_val, variables)\n",
"    grad_vars = list(zip(grads, variables))\n",
"\n",
"    return grad_vars, loss_val"
]
},
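{
"cell_type": "markdown",
"metadata": {},
"source": [
"One caveat with concatenating `model.trainable_variables` and `additional_weights`: if a later TF version starts tracking `tf.get_variable` weights created inside a Model automatically, the same variable would appear twice in the list passed to `tape.gradient` and `apply_gradients`, which, depending on the version, may raise an error or apply its update twice. A defensive variant merges the lists by object identity. `unique_variables` is a hypothetical helper, not a TensorFlow API, and the example uses plain placeholder objects so it runs without TensorFlow.\n",
"\n",
"```python\n",
"def unique_variables(*groups):\n",
"    # Merge several collections of variables, keeping the first occurrence of each object.\n",
"    # Comparing by id() avoids relying on how variable objects implement == or hashing.\n",
"    seen = set()\n",
"    merged = []\n",
"    for group in groups:\n",
"        for var in group:\n",
"            if id(var) not in seen:\n",
"                seen.add(id(var))\n",
"                merged.append(var)\n",
"    return merged\n",
"\n",
"# Placeholder objects standing in for variables\n",
"a, b, c = object(), object(), object()\n",
"print(len(unique_variables([a, b], [b, c])))  # -> 3\n",
"\n",
"# Inside grad(), this would replace the plain concatenation:\n",
"# variables = unique_variables(model.trainable_variables, model.additional_weights.values())\n",
"```"
]
},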
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reason for custom gradient function\n",
"\n",
"Keras has two very convenient methods, `add_weight` and `add_variable`, for adding weights to a layer / model. However, TF Eager mode has not implemented these methods yet, so we can't use them directly.\n",
"\n",
"Instead, we take the roundabout way of mixing Keras layers with custom weights and keeping track of the custom weights in a dictionary."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training and Saving the weights of a custom model\n",
"\n",
"Saving and restoring a custom model works exactly as before, as long as you use `tf.get_variable` to create its variables in Eager mode. However, variables created with `tfe.Variable()` are NOT managed by Keras, so don't use them inside a Keras model."
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch # : 1\n",
"1 : 1.1621509\n",
"2 : 1.1512643\n",
"3 : 1.1079189\n",
"4 : 0.96773857\n",
"5 : 1.0441301\n",
"6 : 0.83432114\n",
"7 : 0.89465904\n",
"8 : 0.8143208\n",
"9 : 0.6209547\n",
"10 : 0.48464486\n",
"11 : 0.6025023\n",
"12 : 0.49791107\n",
"13 : 0.24446094\n",
"14 : 0.15533014\n",
"15 : 0.20884527\n",
"16 : 0.26446357\n",
"17 : 0.2826485\n",
"18 : 0.25964248\n",
"19 : 0.15775986\n",
"20 : 0.14534274\n",
"21 : 0.11215323\n",
"22 : 0.22895297\n",
"23 : 0.16580443\n",
"24 : 0.15820907\n",
"25 : 0.17097807\n",
"26 : 0.16325451\n",
"27 : 0.14715303\n",
"28 : 0.16494317\n",
"29 : 0.16491885\n",
"30 : 0.15860003\n",
"31 : 0.16606785\n",
"32 : 0.14829527\n",
"33 : 0.21525574\n",
"34 : 0.21541195\n",
"35 : 0.15567198\n",
"36 : 0.13704652\n",
"37 : 0.1394914\n",
"38 : 0.15442346\n",
"39 : 0.18070206\n",
"40 : 0.17316034\n",
"41 : 0.1432823\n",
"42 : 0.10533683\n",
"43 : 0.124987096\n",
"44 : 0.20354164\n",
"45 : 0.17832783\n",
"46 : 0.2143028\n",
"47 : 0.1466952\n",
"48 : 0.20074536\n",
"49 : 0.19540063\n",
"50 : 0.14514315\n",
"51 : 0.2010395\n",
"52 : 0.14188452\n",
"53 : 0.17658226\n",
"54 : 0.17075127\n",
"55 : 0.19622569\n",
"56 : 0.20320654\n",
"57 : 0.14674723\n",
"58 : 0.14215888\n",
"59 : 0.14486471\n",
"60 : 0.14302978\n",
"61 : 0.12705547\n",
"62 : 0.21612404\n",
"63 : 0.20794281\n",
"64 : 0.15753731\n",
"65 : 0.22443554\n",
"66 : 0.1892033\n",
"67 : 0.19395569\n",
"68 : 0.1959509\n",
"69 : 0.15914108\n",
"70 : 0.17604938\n",
"71 : 0.19786243\n",
"72 : 0.13813476\n",
"73 : 0.1283803\n",
"74 : 0.12265948\n",
"75 : 0.1276852\n",
"76 : 0.14710434\n",
"77 : 0.15749604\n",
"78 : 0.1914433\n",
"79 : 0.12846845\n",
"80 : 0.15534262\n",
"81 : 0.18991238\n",
"82 : 0.17339572\n",
"83 : 0.16792797\n",
"84 : 0.17180754\n",
"85 : 0.15204497\n",
"86 : 0.13000168\n",
"87 : 0.15898862\n",
"88 : 0.14421916\n",
"89 : 0.18815175\n",
"90 : 0.143107\n",
"91 : 0.15737225\n",
"92 : 0.14719802\n",
"93 : 0.1403911\n",
"94 : 0.16701642\n",
"95 : 0.17096764\n",
"96 : 0.21338543\n",
"97 : 0.23259921\n",
"98 : 0.17520979\n",
"99 : 0.15143956\n",
"100 : 0.18171377\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 2\n",
"101 : 0.17452398\n",
"102 : 0.17398772\n",
"103 : 0.13863611\n",
"104 : 0.18441623\n",
"105 : 0.15283062\n",
"106 : 0.20629707\n",
"107 : 0.18370472\n",
"108 : 0.14907925\n",
"109 : 0.16725613\n",
"110 : 0.19135849\n",
"111 : 0.13866635\n",
"112 : 0.18205793\n",
"113 : 0.14009476\n",
"114 : 0.15662128\n",
"115 : 0.1792382\n",
"116 : 0.2281857\n",
"117 : 0.123989984\n",
"118 : 0.15692824\n",
"119 : 0.13195008\n",
"120 : 0.19298139\n",
"121 : 0.16617489\n",
"122 : 0.15262684\n",
"123 : 0.18068674\n",
"124 : 0.14233781\n",
"125 : 0.081774645\n",
"126 : 0.19077565\n",
"127 : 0.15923782\n",
"128 : 0.21413407\n",
"129 : 0.16872802\n",
"130 : 0.13426398\n",
"131 : 0.18610464\n",
"132 : 0.15890257\n",
"133 : 0.15925054\n",
"134 : 0.19630828\n",
"135 : 0.17393652\n",
"136 : 0.15345761\n",
"137 : 0.17018421\n",
"138 : 0.22589073\n",
"139 : 0.2513543\n",
"140 : 0.13429144\n",
"141 : 0.16034317\n",
"142 : 0.121847115\n",
"143 : 0.15360926\n",
"144 : 0.15041292\n",
"145 : 0.14269821\n",
"146 : 0.17539455\n",
"147 : 0.16700117\n",
"148 : 0.14943682\n",
"149 : 0.16891328\n",
"150 : 0.18500787\n",
"151 : 0.17866719\n",
"152 : 0.14627366\n",
"153 : 0.18164702\n",
"154 : 0.22840938\n",
"155 : 0.22099248\n",
"156 : 0.17002094\n",
"157 : 0.18574873\n",
"158 : 0.16562174\n",
"159 : 0.19159564\n",
"160 : 0.13033281\n",
"161 : 0.18171264\n",
"162 : 0.12429795\n",
"163 : 0.14281641\n",
"164 : 0.12133524\n",
"165 : 0.1362178\n",
"166 : 0.17824344\n",
"167 : 0.17547737\n",
"168 : 0.1661771\n",
"169 : 0.13565828\n",
"170 : 0.16238615\n",
"171 : 0.14953026\n",
"172 : 0.1436495\n",
"173 : 0.17112984\n",
"174 : 0.19373594\n",
"175 : 0.10645476\n",
"176 : 0.18847348\n",
"177 : 0.1684391\n",
"178 : 0.15854037\n",
"179 : 0.20787163\n",
"180 : 0.18528596\n",
"181 : 0.19434765\n",
"182 : 0.1733571\n",
"183 : 0.14257781\n",
"184 : 0.112505585\n",
"185 : 0.14741862\n",
"186 : 0.14854664\n",
"187 : 0.14292587\n",
"188 : 0.13989705\n",
"189 : 0.21581559\n",
"190 : 0.18300533\n",
"191 : 0.13146465\n",
"192 : 0.17924565\n",
"193 : 0.1587025\n",
"194 : 0.21900944\n",
"195 : 0.16573322\n",
"196 : 0.15893798\n",
"197 : 0.13374256\n",
"198 : 0.18655182\n",
"199 : 0.15289505\n",
"200 : 0.10996998\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 3\n",
"201 : 0.22186726\n",
"202 : 0.13843639\n",
"203 : 0.12655365\n",
"204 : 0.17936353\n",
"205 : 0.13275121\n",
"206 : 0.21674049\n",
"207 : 0.1770219\n",
"208 : 0.19964287\n",
"209 : 0.19358288\n",
"210 : 0.16396336\n",
"211 : 0.16925587\n",
"212 : 0.13012736\n",
"213 : 0.18705796\n",
"214 : 0.19048297\n",
"215 : 0.16426483\n",
"216 : 0.14111724\n",
"217 : 0.18658395\n",
"218 : 0.10857462\n",
"219 : 0.14023249\n",
"220 : 0.17722149\n",
"221 : 0.16052969\n",
"222 : 0.19683939\n",
"223 : 0.17165703\n",
"224 : 0.15519667\n",
"225 : 0.14174989\n",
"226 : 0.15616634\n",
"227 : 0.1566721\n",
"228 : 0.10799822\n",
"229 : 0.19913803\n",
"230 : 0.16529432\n",
"231 : 0.15814413\n",
"232 : 0.17272209\n",
"233 : 0.13623579\n",
"234 : 0.12063271\n",
"235 : 0.1386844\n",
"236 : 0.1740633\n",
"237 : 0.18553345\n",
"238 : 0.19261423\n",
"239 : 0.1411343\n",
"240 : 0.1138795\n",
"241 : 0.1878658\n",
"242 : 0.14414324\n",
"243 : 0.16839427\n",
"244 : 0.15615423\n",
"245 : 0.17451763\n",
"246 : 0.16337463\n",
"247 : 0.13811496\n",
"248 : 0.15342808\n",
"249 : 0.18247299\n",
"250 : 0.1623476\n",
"251 : 0.13821931\n",
"252 : 0.15814374\n",
"253 : 0.18803948\n",
"254 : 0.14420147\n",
"255 : 0.24036933\n",
"256 : 0.14577095\n",
"257 : 0.11556837\n",
"258 : 0.21571127\n",
"259 : 0.20865135\n",
"260 : 0.15414356\n",
"261 : 0.16331671\n",
"262 : 0.1496009\n",
"263 : 0.16194119\n",
"264 : 0.17524152\n",
"265 : 0.15429325\n",
"266 : 0.106552295\n",
"267 : 0.12640072\n",
"268 : 0.17571625\n",
"269 : 0.15392298\n",
"270 : 0.1721752\n",
"271 : 0.19228888\n",
"272 : 0.12010077\n",
"273 : 0.13864908\n",
"274 : 0.22513275\n",
"275 : 0.17092635\n",
"276 : 0.12420061\n",
"277 : 0.2013131\n",
"278 : 0.15654771\n",
"279 : 0.19830738\n",
"280 : 0.15181726\n",
"281 : 0.136528\n",
"282 : 0.2045887\n",
"283 : 0.12790914\n",
"284 : 0.14993207\n",
"285 : 0.16619124\n",
"286 : 0.15152024\n",
"287 : 0.22683269\n",
"288 : 0.14017643\n",
"289 : 0.15678146\n",
"290 : 0.13380767\n",
"291 : 0.15817656\n",
"292 : 0.17862025\n",
"293 : 0.17340073\n",
"294 : 0.17777954\n",
"295 : 0.1471033\n",
"296 : 0.16436699\n",
"297 : 0.19461739\n",
"298 : 0.16554247\n",
"299 : 0.21333084\n",
"300 : 0.12700416\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 4\n",
"301 : 0.16582884\n",
"302 : 0.20898075\n",
"303 : 0.1750078\n",
"304 : 0.09680592\n",
"305 : 0.1415002\n",
"306 : 0.16258782\n",
"307 : 0.14485079\n",
"308 : 0.14240536\n",
"309 : 0.18198383\n",
"310 : 0.12266002\n",
"311 : 0.1777145\n",
"312 : 0.18263713\n",
"313 : 0.18234497\n",
"314 : 0.19265197\n",
"315 : 0.16840225\n",
"316 : 0.16083483\n",
"317 : 0.18632455\n",
"318 : 0.19441384\n",
"319 : 0.14598598\n",
"320 : 0.16428588\n",
"321 : 0.13692603\n",
"322 : 0.17836694\n",
"323 : 0.15277189\n",
"324 : 0.20046842\n",
"325 : 0.15139152\n",
"326 : 0.14826779\n",
"327 : 0.17124844\n",
"328 : 0.16631278\n",
"329 : 0.16364838\n",
"330 : 0.14954187\n",
"331 : 0.15680958\n",
"332 : 0.15733175\n",
"333 : 0.13616699\n",
"334 : 0.15219837\n",
"335 : 0.16273442\n",
"336 : 0.17518112\n",
"337 : 0.15502705\n",
"338 : 0.1610923\n",
"339 : 0.18044762\n",
"340 : 0.12562336\n",
"341 : 0.16819558\n",
"342 : 0.11694947\n",
"343 : 0.1576317\n",
"344 : 0.16258074\n",
"345 : 0.13882609\n",
"346 : 0.13120383\n",
"347 : 0.15734188\n",
"348 : 0.15512449\n",
"349 : 0.116926\n",
"350 : 0.1577882\n",
"351 : 0.17788224\n",
"352 : 0.1452875\n",
"353 : 0.18948543\n",
"354 : 0.13448155\n",
"355 : 0.19038075\n",
"356 : 0.10368092\n",
"357 : 0.16536602\n",
"358 : 0.16239507\n",
"359 : 0.13405462\n",
"360 : 0.15833771\n",
"361 : 0.1340491\n",
"362 : 0.16511582\n",
"363 : 0.21111917\n",
"364 : 0.16407947\n",
"365 : 0.16947262\n",
"366 : 0.15729624\n",
"367 : 0.13916157\n",
"368 : 0.16625576\n",
"369 : 0.1313562\n",
"370 : 0.21306169\n",
"371 : 0.15813984\n",
"372 : 0.17548335\n",
"373 : 0.18683027\n",
"374 : 0.16156863\n",
"375 : 0.15792435\n",
"376 : 0.14098401\n",
"377 : 0.117002755\n",
"378 : 0.16242577\n",
"379 : 0.16778725\n",
"380 : 0.151508\n",
"381 : 0.2292648\n",
"382 : 0.14269243\n",
"383 : 0.12150015\n",
"384 : 0.12754649\n",
"385 : 0.1621392\n",
"386 : 0.18651927\n",
"387 : 0.15823685\n",
"388 : 0.15176812\n",
"389 : 0.15039288\n",
"390 : 0.14135158\n",
"391 : 0.1985403\n",
"392 : 0.11613184\n",
"393 : 0.117393866\n",
"394 : 0.20429797\n",
"395 : 0.15860136\n",
"396 : 0.0956112\n",
"397 : 0.12968332\n",
"398 : 0.15787499\n",
"399 : 0.14649217\n",
"400 : 0.12882891\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 5\n",
"401 : 0.16940765\n",
"402 : 0.17282967\n",
"403 : 0.16527462\n",
"404 : 0.18525513\n",
"405 : 0.15709756\n",
"406 : 0.16736652\n",
"407 : 0.17818806\n",
"408 : 0.097625904\n",
"409 : 0.14921002\n",
"410 : 0.16864675\n",
"411 : 0.14301768\n",
"412 : 0.15432054\n",
"413 : 0.16860722\n",
"414 : 0.12987046\n",
"415 : 0.15092702\n",
"416 : 0.12399706\n",
"417 : 0.13464053\n",
"418 : 0.15270242\n",
"419 : 0.14676133\n",
"420 : 0.15388036\n",
"421 : 0.123324625\n",
"422 : 0.17676713\n",
"423 : 0.15827824\n",
"424 : 0.13404016\n",
"425 : 0.15690748\n",
"426 : 0.13800992\n",
"427 : 0.116831884\n",
"428 : 0.14902116\n",
"429 : 0.09627804\n",
"430 : 0.15291607\n",
"431 : 0.10853407\n",
"432 : 0.16284384\n",
"433 : 0.119379066\n",
"434 : 0.14874731\n",
"435 : 0.15419668\n",
"436 : 0.10974857\n",
"437 : 0.16883971\n",
"438 : 0.11674813\n",
"439 : 0.1646753\n",
"440 : 0.14486031\n",
"441 : 0.118678264\n",
"442 : 0.10875624\n",
"443 : 0.14434694\n",
"444 : 0.12476181\n",
"445 : 0.115751676\n",
"446 : 0.12217863\n",
"447 : 0.1143231\n",
"448 : 0.117766485\n",
"449 : 0.1638347\n",
"450 : 0.14917491\n",
"451 : 0.11963312\n",
"452 : 0.14899231\n",
"453 : 0.17762177\n",
"454 : 0.15127851\n",
"455 : 0.15726463\n",
"456 : 0.109702684\n",
"457 : 0.14533967\n",
"458 : 0.16777302\n",
"459 : 0.14042729\n",
"460 : 0.15578504\n",
"461 : 0.118589595\n",
"462 : 0.15657058\n",
"463 : 0.19722565\n",
"464 : 0.1421519\n",
"465 : 0.10931032\n",
"466 : 0.13820136\n",
"467 : 0.13129091\n",
"468 : 0.16731602\n",
"469 : 0.1720289\n",
"470 : 0.15954135\n",
"471 : 0.13708836\n",
"472 : 0.12804864\n",
"473 : 0.10557709\n",
"474 : 0.14164223\n",
"475 : 0.162804\n",
"476 : 0.11445169\n",
"477 : 0.12999205\n",
"478 : 0.18466061\n",
"479 : 0.13450158\n",
"480 : 0.13166414\n",
"481 : 0.15271242\n",
"482 : 0.1544477\n",
"483 : 0.08261827\n",
"484 : 0.11228782\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"485 : 0.16018173\n",
"486 : 0.11356828\n",
"487 : 0.0773872\n",
"488 : 0.096311375\n",
"489 : 0.10748297\n",
"490 : 0.11437002\n",
"491 : 0.10593601\n",
"492 : 0.14546043\n",
"493 : 0.09018659\n",
"494 : 0.13730581\n",
"495 : 0.13203639\n",
"496 : 0.058265384\n",
"497 : 0.105687656\n",
"498 : 0.14168152\n",
"499 : 0.084342234\n",
"500 : 0.06757977\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 6\n",
"501 : 0.07952868\n",
"502 : 0.07739203\n",
"503 : 0.18215778\n",
"504 : 0.070043966\n",
"505 : 0.103158206\n",
"506 : 0.08989197\n",
"507 : 0.09987702\n",
"508 : 0.0744272\n",
"509 : 0.07997977\n",
"510 : 0.082261056\n",
"511 : 0.10555805\n",
"512 : 0.10141922\n",
"513 : 0.09325267\n",
"514 : 0.06462287\n",
"515 : 0.04948257\n",
"516 : 0.076712884\n",
"517 : 0.068328306\n",
"518 : 0.06573434\n",
"519 : 0.088249646\n",
"520 : 0.05284898\n",
"521 : 0.122354336\n",
"522 : 0.11334232\n",
"523 : 0.07055436\n",
"524 : 0.14312787\n",
"525 : 0.06019887\n",
"526 : 0.06437059\n",
"527 : 0.1180656\n",
"528 : 0.10588083\n",
"529 : 0.07076761\n",
"530 : 0.104888335\n",
"531 : 0.08440046\n",
"532 : 0.06800649\n",
"533 : 0.069036074\n",
"534 : 0.068309695\n",
"535 : 0.0881553\n",
"536 : 0.07721259\n",
"537 : 0.093113564\n",
"538 : 0.08860032\n",
"539 : 0.07878672\n",
"540 : 0.06988148\n",
"541 : 0.060777463\n",
"542 : 0.060478255\n",
"543 : 0.076639816\n",
"544 : 0.08247948\n",
"545 : 0.046171255\n",
"546 : 0.0806703\n",
"547 : 0.047576375\n",
"548 : 0.06681472\n",
"549 : 0.054443426\n",
"550 : 0.07487729\n",
"551 : 0.0554863\n",
"552 : 0.078380615\n",
"553 : 0.050209466\n",
"554 : 0.055210892\n",
"555 : 0.06537444\n",
"556 : 0.04991782\n",
"557 : 0.06462203\n",
"558 : 0.058467485\n",
"559 : 0.05718185\n",
"560 : 0.045877185\n",
"561 : 0.059709206\n",
"562 : 0.04973532\n",
"563 : 0.06375302\n",
"564 : 0.050100185\n",
"565 : 0.05931464\n",
"566 : 0.0660767\n",
"567 : 0.045827497\n",
"568 : 0.07988288\n",
"569 : 0.056597807\n",
"570 : 0.052850157\n",
"571 : 0.057490874\n",
"572 : 0.057829086\n",
"573 : 0.05794095\n",
"574 : 0.056026462\n",
"575 : 0.050000332\n",
"576 : 0.06738357\n",
"577 : 0.043959126\n",
"578 : 0.08911047\n",
"579 : 0.04238353\n",
"580 : 0.074922666\n",
"581 : 0.075716466\n",
"582 : 0.050799213\n",
"583 : 0.070965655\n",
"584 : 0.0860744\n",
"585 : 0.06125041\n",
"586 : 0.073660895\n",
"587 : 0.061381873\n",
"588 : 0.057871807\n",
"589 : 0.048660744\n",
"590 : 0.09694215\n",
"591 : 0.06389055\n",
"592 : 0.055375315\n",
"593 : 0.09445058\n",
"594 : 0.06687753\n",
"595 : 0.054516762\n",
"596 : 0.06662418\n",
"597 : 0.09523481\n",
"598 : 0.0533651\n",
"599 : 0.053067446\n",
"600 : 0.05575211\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 7\n",
"601 : 0.07736846\n",
"602 : 0.0552001\n",
"603 : 0.04574337\n",
"604 : 0.050139762\n",
"605 : 0.048997816\n",
"606 : 0.059650965\n",
"607 : 0.06689899\n",
"608 : 0.09823707\n",
"609 : 0.069393195\n",
"610 : 0.036928907\n",
"611 : 0.06705368\n",
"612 : 0.055885483\n",
"613 : 0.0430569\n",
"614 : 0.07279368\n",
"615 : 0.061935645\n",
"616 : 0.07013589\n",
"617 : 0.03643005\n",
"618 : 0.063131124\n",
"619 : 0.07096699\n",
"620 : 0.042157035\n",
"621 : 0.03867495\n",
"622 : 0.08408603\n",
"623 : 0.07764095\n",
"624 : 0.05555534\n",
"625 : 0.060472623\n",
"626 : 0.05765079\n",
"627 : 0.06436953\n",
"628 : 0.05046819\n",
"629 : 0.04349673\n",
"630 : 0.1141779\n",
"631 : 0.07322818\n",
"632 : 0.028088273\n",
"633 : 0.09079697\n",
"634 : 0.0712476\n",
"635 : 0.050048884\n",
"636 : 0.050365344\n",
"637 : 0.08910738\n",
"638 : 0.058636155\n",
"639 : 0.039823867\n",
"640 : 0.05089892\n",
"641 : 0.07196513\n",
"642 : 0.044761654\n",
"643 : 0.06636423\n",
"644 : 0.07053846\n",
"645 : 0.07046036\n",
"646 : 0.043274842\n",
"647 : 0.043768436\n",
"648 : 0.043393165\n",
"649 : 0.052902985\n",
"650 : 0.04615542\n",
"651 : 0.03631371\n",
"652 : 0.06293922\n",
"653 : 0.04975712\n",
"654 : 0.06299912\n",
"655 : 0.036902577\n",
"656 : 0.0612465\n",
"657 : 0.05225478\n",
"658 : 0.038406122\n",
"659 : 0.038051557\n",
"660 : 0.035875544\n",
"661 : 0.038937297\n",
"662 : 0.026996393\n",
"663 : 0.03250654\n",
"664 : 0.04050415\n",
"665 : 0.031349894\n",
"666 : 0.039157156\n",
"667 : 0.04135289\n",
"668 : 0.02690461\n",
"669 : 0.039351813\n",
"670 : 0.022571877\n",
"671 : 0.040296078\n",
"672 : 0.03502075\n",
"673 : 0.029684229\n",
"674 : 0.044415906\n",
"675 : 0.020568905\n",
"676 : 0.029217523\n",
"677 : 0.025536818\n",
"678 : 0.030433211\n",
"679 : 0.033477653\n",
"680 : 0.052643117\n",
"681 : 0.03186213\n",
"682 : 0.01811907\n",
"683 : 0.037468776\n",
"684 : 0.03816268\n",
"685 : 0.02157622\n",
"686 : 0.03606137\n",
"687 : 0.035710596\n",
"688 : 0.02892365\n",
"689 : 0.021856418\n",
"690 : 0.033755653\n",
"691 : 0.026426092\n",
"692 : 0.026615009\n",
"693 : 0.035342086\n",
"694 : 0.023439674\n",
"695 : 0.03364588\n",
"696 : 0.02135745\n",
"697 : 0.031256534\n",
"698 : 0.028182926\n",
"699 : 0.022584226\n",
"700 : 0.019639416\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 8\n",
"701 : 0.026487103\n",
"702 : 0.018899292\n",
"703 : 0.020177234\n",
"704 : 0.025114961\n",
"705 : 0.021267634\n",
"706 : 0.034305304\n",
"707 : 0.033459406\n",
"708 : 0.024987824\n",
"709 : 0.018467478\n",
"710 : 0.022687536\n",
"711 : 0.022842783\n",
"712 : 0.023766942\n",
"713 : 0.02352619\n",
"714 : 0.016435266\n",
"715 : 0.025214477\n",
"716 : 0.02603365\n",
"717 : 0.021770844\n",
"718 : 0.016647704\n",
"719 : 0.022101106\n",
"720 : 0.022887666\n",
"721 : 0.021594457\n",
"722 : 0.025939822\n",
"723 : 0.029485822\n",
"724 : 0.0223991\n",
"725 : 0.02487504\n",
"726 : 0.015103959\n",
"727 : 0.020116957\n",
"728 : 0.024512038\n",
"729 : 0.016456312\n",
"730 : 0.02474121\n",
"731 : 0.021990923\n",
"732 : 0.022414463\n",
"733 : 0.02273764\n",
"734 : 0.018195996\n",
"735 : 0.015131655\n",
"736 : 0.018630486\n",
"737 : 0.025349317\n",
"738 : 0.021785831\n",
"739 : 0.024760304\n",
"740 : 0.020802412\n",
"741 : 0.020911641\n",
"742 : 0.02492898\n",
"743 : 0.021754032\n",
"744 : 0.02905923\n",
"745 : 0.03147601\n",
"746 : 0.014711869\n",
"747 : 0.017211068\n",
"748 : 0.031051988\n",
"749 : 0.025224261\n",
"750 : 0.035324324\n",
"751 : 0.023175893\n",
"752 : 0.02357918\n",
"753 : 0.024354227\n",
"754 : 0.014987478\n",
"755 : 0.019780124\n",
"756 : 0.03400721\n",
"757 : 0.018811047\n",
"758 : 0.024055524\n",
"759 : 0.021259477\n",
"760 : 0.022887683\n",
"761 : 0.026393939\n",
"762 : 0.028880611\n",
"763 : 0.019565068\n",
"764 : 0.029858215\n",
"765 : 0.015567216\n",
"766 : 0.027770605\n",
"767 : 0.021531155\n",
"768 : 0.015297752\n",
"769 : 0.015660807\n",
"770 : 0.024965491\n",
"771 : 0.012396178\n",
"772 : 0.022141272\n",
"773 : 0.027769739\n",
"774 : 0.038974945\n",
"775 : 0.02044248\n",
"776 : 0.021008821\n",
"777 : 0.014710438\n",
"778 : 0.015092858\n",
"779 : 0.015304186\n",
"780 : 0.015131209\n",
"781 : 0.018996174\n",
"782 : 0.02421886\n",
"783 : 0.016571632\n",
"784 : 0.01445453\n",
"785 : 0.018240837\n",
"786 : 0.012433418\n",
"787 : 0.009910008\n",
"788 : 0.01312879\n",
"789 : 0.012092981\n",
"790 : 0.01402854\n",
"791 : 0.021272555\n",
"792 : 0.01436323\n",
"793 : 0.013599489\n",
"794 : 0.014096321\n",
"795 : 0.021129727\n",
"796 : 0.010184152\n",
"797 : 0.014366507\n",
"798 : 0.011746323\n",
"799 : 0.014906851\n",
"800 : 0.017781395\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 9\n",
"801 : 0.030030632\n",
"802 : 0.014365942\n",
"803 : 0.012845557\n",
"804 : 0.021550978\n",
"805 : 0.01814298\n",
"806 : 0.02284931\n",
"807 : 0.0237127\n",
"808 : 0.019042442\n",
"809 : 0.019559529\n",
"810 : 0.02465231\n",
"811 : 0.014367839\n",
"812 : 0.025342941\n",
"813 : 0.02344511\n",
"814 : 0.020890992\n",
"815 : 0.01936003\n",
"816 : 0.01787552\n",
"817 : 0.015061614\n",
"818 : 0.0149472775\n",
"819 : 0.016823359\n",
"820 : 0.014147806\n",
"821 : 0.018946111\n",
"822 : 0.015606833\n",
"823 : 0.017901631\n",
"824 : 0.014592238\n",
"825 : 0.01583731\n",
"826 : 0.019414708\n",
"827 : 0.01908092\n",
"828 : 0.012014141\n",
"829 : 0.014311067\n",
"830 : 0.010503138\n",
"831 : 0.010355268\n",
"832 : 0.01041521\n",
"833 : 0.018829599\n",
"834 : 0.020979555\n",
"835 : 0.013075921\n",
"836 : 0.014609251\n",
"837 : 0.0144327665\n",
"838 : 0.01754508\n",
"839 : 0.011855249\n",
"840 : 0.01767914\n",
"841 : 0.0109841265\n",
"842 : 0.007160005\n",
"843 : 0.014209363\n",
"844 : 0.006966999\n",
"845 : 0.008116461\n",
"846 : 0.012841169\n",
"847 : 0.011620861\n",
"848 : 0.014408579\n",
"849 : 0.01348294\n",
"850 : 0.01201184\n",
"851 : 0.009347877\n",
"852 : 0.015454381\n",
"853 : 0.010721075\n",
"854 : 0.013120465\n",
"855 : 0.020850709\n",
"856 : 0.014667712\n",
"857 : 0.017732613\n",
"858 : 0.014344544\n",
"859 : 0.015992176\n",
"860 : 0.01417166\n",
"861 : 0.015028724\n",
"862 : 0.011597834\n",
"863 : 0.013647451\n",
"864 : 0.01674522\n",
"865 : 0.014058041\n",
"866 : 0.014591095\n",
"867 : 0.011233372\n",
"868 : 0.011675154\n",
"869 : 0.015009872\n",
"870 : 0.011655096\n",
"871 : 0.0137200495\n",
"872 : 0.017229075\n",
"873 : 0.012109876\n",
"874 : 0.014179173\n",
"875 : 0.0133238025\n",
"876 : 0.018385964\n",
"877 : 0.017418642\n",
"878 : 0.022792019\n",
"879 : 0.012964742\n",
"880 : 0.0123944925\n",
"881 : 0.0120278625\n",
"882 : 0.015933165\n",
"883 : 0.014295025\n",
"884 : 0.014082588\n",
"885 : 0.011272328\n",
"886 : 0.010991094\n",
"887 : 0.011151891\n",
"888 : 0.01277179\n",
"889 : 0.01260668\n",
"890 : 0.012407446\n",
"891 : 0.012278632\n",
"892 : 0.008859095\n",
"893 : 0.01429797\n",
"894 : 0.014462714\n",
"895 : 0.009702824\n",
"896 : 0.011556063\n",
"897 : 0.013805108\n",
"898 : 0.02522901\n",
"899 : 0.00902943\n",
"900 : 0.0122801205\n",
"\n",
"Saving weights\n",
"\n",
"Epoch # : 10\n",
"901 : 0.00901419\n",
"902 : 0.012053919\n",
"903 : 0.013461079\n",
"904 : 0.015934672\n",
"905 : 0.0075587807\n",
"906 : 0.01101004\n",
"907 : 0.010823178\n",
"908 : 0.011288586\n",
"909 : 0.010599801\n",
"910 : 0.011733476\n",
"911 : 0.013157886\n",
"912 : 0.017425172\n",
"913 : 0.008150251\n",
"914 : 0.014784017\n",
"915 : 0.015230116\n",
"916 : 0.012863653\n",
"917 : 0.011939194\n",
"918 : 0.014136527\n",
"919 : 0.014618843\n",
"920 : 0.0092842\n",
"921 : 0.013114586\n",
"922 : 0.015383261\n",
"923 : 0.0070189848\n",
"924 : 0.01075149\n",
"925 : 0.012399299\n",
"926 : 0.017086044\n",
"927 : 0.021438591\n",
"928 : 0.019750623\n",
"929 : 0.014763551\n",
"930 : 0.034406144\n",
"931 : 0.01718176\n",
"932 : 0.018059038\n",
"933 : 0.010594039\n",
"934 : 0.011660888\n",
"935 : 0.012951542\n",
"936 : 0.012108057\n",
"937 : 0.014253226\n",
"938 : 0.014137793\n",
"939 : 0.014750669\n",
"940 : 0.010612033\n",
"941 : 0.013792441\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"942 : 0.015041292\n",
"943 : 0.0144982245\n",
"944 : 0.0101460675\n",
"945 : 0.011825992\n",
"946 : 0.019285101\n",
"947 : 0.009109263\n",
"948 : 0.0111296605\n",
"949 : 0.014955539\n",
"950 : 0.007480563\n",
"951 : 0.013984419\n",
"952 : 0.008782578\n",
"953 : 0.012421405\n",
"954 : 0.007365109\n",
"955 : 0.0088006975\n",
"956 : 0.009729772\n",
"957 : 0.0069292225\n",
"958 : 0.010098675\n",
"959 : 0.010997141\n",
"960 : 0.012169895\n",
"961 : 0.011444533\n",
"962 : 0.01399242\n",
"963 : 0.0068554753\n",
"964 : 0.00894353\n",
"965 : 0.012242434\n",
"966 : 0.01021545\n",
"967 : 0.012649863\n",
"968 : 0.007302314\n",
"969 : 0.010416012\n",
"970 : 0.006149518\n",
"971 : 0.010594614\n",
"972 : 0.015516937\n",
"973 : 0.00739812\n",
"974 : 0.010529751\n",
"975 : 0.009240895\n",
"976 : 0.013352713\n",
"977 : 0.008962644\n",
"978 : 0.010766876\n",
"979 : 0.009021389\n",
"980 : 0.011382189\n",
"981 : 0.011743964\n",
"982 : 0.0120130135\n",
"983 : 0.015135792\n",
"984 : 0.0071727526\n",
"985 : 0.0064567816\n",
"986 : 0.01175874\n",
"987 : 0.0115865255\n",
"988 : 0.009993628\n",
"989 : 0.01362257\n",
"990 : 0.011902858\n",
"991 : 0.007314901\n",
"992 : 0.015340851\n",
"993 : 0.015841175\n",
"994 : 0.007128427\n",
"995 : 0.007353345\n",
"996 : 0.011006472\n",
"997 : 0.008887305\n",
"998 : 0.009906682\n",
"999 : 0.016090322\n",
"1000 : 0.009503086\n",
"\n",
"Saving weights\n",
"\n"
]
}
],
"source": [
"with tf.device('/gpu:0'):\n",
"    input_dim = 2\n",
"    model = EagerJANETModel(input_dim, NUM_UNITS, num_outputs=1, num_timesteps=TIME_STEPS, output_activation='linear')\n",
"\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)\n",
"    global_step = tf.train.get_or_create_global_step()\n",
"\n",
"    # try using different optimizers and different optimizer configs\n",
"    model.compile(loss='mse', optimizer=optimizer)\n",
"\n",
"    best_loss = 100.\n",
"    generator = batch_generator()\n",
"\n",
"    loss_history = []\n",
"    update_counter = 1\n",
"\n",
"    for epoch in range(NUM_EPOCHS):\n",
"        print(\"Epoch # : \", epoch + 1)\n",
"\n",
"        for step in range(STEPS_PER_EPOCH):\n",
"            # get the next batch of inputs and targets\n",
"            inputs, targets = next(generator)\n",
"\n",
"            # compute the loss and its gradients for this batch (uses the grad helper)\n",
"            gradients, loss_val = grad(model, inputs, targets)\n",
"\n",
"            # apply the gradients and increment the global step\n",
"            optimizer.apply_gradients(gradients, global_step=global_step)\n",
"\n",
"            loss_history.append(loss_val.numpy())\n",
"            print(update_counter, \":\", loss_history[-1])\n",
"\n",
"            update_counter += 1\n",
"\n",
"        print()\n",
"\n",
"        # remove the old checkpoint, which we no longer need\n",
"        if os.path.exists(CHECKPOINTS_DIR):\n",
"            shutil.rmtree(CHECKPOINTS_DIR)\n",
"\n",
"        # save the checkpoint weights\n",
"        # tfe.Checkpoint tracks objects by keyword name, so the model's variables are saved under 'model'\n",
"        save_path = tfe.Checkpoint(model=model).save(CHECKPOINT_PATH)\n",
"\n",
"        # Optional : Save the weight matrices in Keras format as well\n",
"        #model.save_weights('addition_model.h5', overwrite=True)\n",
"\n",
"        print(\"Saving weights\")\n",
"        print()"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x20e8fef5320>]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x20df650dac8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# drop losses >= 0.25 (mostly from the initial epochs) so the rest of the curve is readable\n",
"loss_history_plot = list(filter(lambda x: x < 0.25, loss_history))\n",
"\n",
"plt.figure(figsize=(12, 5))\n",
"plt.plot(loss_history_plot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Restoring custom models\n",
"\n",
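"Because of the added complexity of saving hybrid models, loading the custom model's weights from the checkpoint takes a slightly different path: rebuild the model, call it once on a dummy input so that all of its variables are created, and then restore them with `tfe.Checkpoint`."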
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Checkpoint path : checkpoints_addition/addition_janet.ckpt-1\n",
"Final average predicted error (should be less than 0.03) : 0.010060601\n"
]
}
],
"source": [
"if os.path.exists(CHECKPOINTS_DIR):\n",
"    ckpt_path = tf.train.latest_checkpoint(CHECKPOINTS_DIR)\n",
"    print(\"Checkpoint path : \", ckpt_path)\n",
"\n",
"    model = EagerJANETModel(input_dim, NUM_UNITS, num_outputs=1, num_timesteps=TIME_STEPS, output_activation='linear')\n",
"\n",
"    model.compile(tf.train.AdamOptimizer(), loss='mse')\n",
"\n",
"    # call the model at least once here, so that all of its variables\n",
"    # are created and can be properly restored\n",
"    zeros = tf.zeros((1, TIME_STEPS, 2))\n",
"    model(zeros)\n",
"\n",
"    # restore the weights that were saved under the 'model' key\n",
"    tfe.Checkpoint(model=model).restore(ckpt_path)\n",
"\n",
"    # predict 20 batches and check the error to make sure the weights loaded correctly\n",
"    generator = batch_generator()\n",
"\n",
"    losses = []\n",
"    for i in range(20):\n",
"        inputs, outputs = next(generator)\n",
"\n",
"        preds = model(inputs)\n",
"        loss_val = loss(outputs, preds)\n",
"\n",
"        losses.append(loss_val.numpy())\n",
"\n",
"    print(\"Final average predicted error (should be less than 0.03) : \", np.mean(losses))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:Anaconda3]",
"language": "python",
"name": "conda-env-Anaconda3-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}