Last active
February 23, 2019 08:16
-
-
Save suyash/539971f5d90b25fc95ac4714208c79ee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"from keras.layers import Dense, Input\n", | |
"from keras.models import Model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Working Example" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model1(Model):\n", | |
" def __init__(self, **kwargs):\n", | |
" super(Model1, self).__init__(**kwargs)\n", | |
" \n", | |
" self.m1 = Dense(64)\n", | |
" self.m2 = Dense(64)\n", | |
" self.m3 = Dense(64)\n", | |
" self.l = Dense(32, activation=\"relu\")\n", | |
" \n", | |
" def call(self, i):\n", | |
" o = i\n", | |
" o = self.m1(o)\n", | |
" o = self.m2(o)\n", | |
" o = self.m3(o)\n", | |
" return self.l(o)\n", | |
" \n", | |
" def compute_output_shape(self, input_shape):\n", | |
" return (input_shape[0], 32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"i = Input((32,))\n", | |
"net = Model1()(i)\n", | |
"net = Dense(2, activation=\"softmax\")(net)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = Model(i, net)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_1 (InputLayer) (None, 32) 0 \n", | |
"_________________________________________________________________\n", | |
"model1_1 (Model1) (None, 32) 12512 \n", | |
"_________________________________________________________________\n", | |
"dense_5 (Dense) (None, 2) 66 \n", | |
"=================================================================\n", | |
"Total params: 12,578\n", | |
"Trainable params: 12,578\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Buggy Example" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model2(Model):\n", | |
" def __init__(self, **kwargs):\n", | |
" super(Model2, self).__init__(**kwargs)\n", | |
" \n", | |
" self.m = [Dense(64) for _ in range(3)]\n", | |
" self.l = Dense(32, activation=\"relu\")\n", | |
" \n", | |
" def call(self, i):\n", | |
" o = i\n", | |
" for l in self.m:\n", | |
" print(l)\n", | |
" o = l(o)\n", | |
" return self.l(o)\n", | |
" \n", | |
" def compute_output_shape(self, input_shape):\n", | |
" return (input_shape[0], 32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<keras.layers.core.Dense object at 0xb387f5358>\n", | |
"<keras.layers.core.Dense object at 0xb387f5278>\n", | |
"<keras.layers.core.Dense object at 0xb387f57f0>\n" | |
] | |
} | |
], | |
"source": [ | |
"i = Input((32,))\n", | |
"net = Model2()(i)\n", | |
"net = Dense(2, activation=\"softmax\")(net)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = Model(i, net)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_2 (InputLayer) (None, 32) 0 \n", | |
"_________________________________________________________________\n", | |
"model2_1 (Model2) (None, 32) 2080 \n", | |
"_________________________________________________________________\n", | |
"dense_10 (Dense) (None, 2) 66 \n", | |
"=================================================================\n", | |
"Total params: 2,146\n", | |
"Trainable params: 2,146\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"model.summary()" | |
] | |
}, | |
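{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"NOTE: the total should be 12,578, as in the working example. Only the 2,080 parameters of `self.l` are counted; the three `Dense(64)` layers held in the list `self.m` are missing from the summary." | |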
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<keras.engine.input_layer.InputLayer at 0xb387f52b0>,\n", | |
" <__main__.Model2 at 0xb387f51d0>,\n", | |
" <keras.layers.core.Dense at 0xb387f5400>]" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<keras.layers.core.Dense at 0xb387f5898>]" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.layers[1].layers" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"NOTE: only 1 sub-layer is listed instead of 4; the three `Dense(64)` layers stored in the list `self.m` are missing." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Same list-based model with `tensorflow.keras`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model3(tf.keras.models.Model):\n", | |
" def __init__(self, **kwargs):\n", | |
" super(Model3, self).__init__(**kwargs)\n", | |
" \n", | |
" self.m = [tf.keras.layers.Dense(64) for _ in range(3)]\n", | |
" self.l = tf.keras.layers.Dense(32, activation=\"relu\")\n", | |
" \n", | |
" def call(self, i):\n", | |
" o = i\n", | |
" for l in self.m:\n", | |
" print(l)\n", | |
" o = l(o)\n", | |
" return self.l(o)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<tensorflow.python.keras.layers.core.Dense object at 0xb388c9be0>\n", | |
"<tensorflow.python.keras.layers.core.Dense object at 0xb388c9e48>\n", | |
"<tensorflow.python.keras.layers.core.Dense object at 0xb388d40b8>\n" | |
] | |
} | |
], | |
"source": [ | |
"i = tf.keras.layers.Input((32,))\n", | |
"net = Model3()(i)\n", | |
"net = tf.keras.layers.Dense(2, activation=\"softmax\")(net)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = tf.keras.models.Model(i, net)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_1 (InputLayer) (None, 32) 0 \n", | |
"_________________________________________________________________\n", | |
"model3 (Model3) (None, 32) 12512 \n", | |
"_________________________________________________________________\n", | |
"dense_4 (Dense) (None, 2) 66 \n", | |
"=================================================================\n", | |
"Total params: 12,578\n", | |
"Trainable params: 12,578\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<tensorflow.python.keras.engine.input_layer.InputLayer at 0xb388c90b8>,\n", | |
" <__main__.Model3 at 0xb388c9080>,\n", | |
" <tensorflow.python.keras.layers.core.Dense at 0xb388c9a20>]" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<tensorflow.python.keras.layers.core.Dense at 0xb388c9be0>,\n", | |
" <tensorflow.python.keras.layers.core.Dense at 0xb388c9e48>,\n", | |
" <tensorflow.python.keras.layers.core.Dense at 0xb388d40b8>,\n", | |
" <tensorflow.python.keras.layers.core.Dense at 0xb388d4240>]" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.layers[1].layers" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"NOTE: working as expected. With `tf.keras`, all 4 sub-layers are listed and the total is 12,578 parameters." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment