Created
August 2, 2017 03:46
-
-
Save myurasov/d96f932091ac0660982356770fd81ae6 to your computer and use it in GitHub Desktop.
Training with categorical_crossentropy vs sparse_categorical_crossentropy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
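{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Training with `categorical_crossentropy` vs `sparse_categorical_crossentropy`\n\nA two-output toy model is trained on a single random image, first with `categorical_crossentropy` and then with `sparse_categorical_crossentropy` on the 10-way head. The two losses differ in the target format they accept: `categorical_crossentropy` takes a target with one value per class (the same shape as the model output), while `sparse_categorical_crossentropy` takes integer class indices." | |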
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import numpy as np\nfrom keras.layers import Input, Dense, Flatten\nfrom keras.models import Model", | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "Using TensorFlow backend.\n", | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"collapsed": true | |
}, | |
"cell_type": "code", | |
"source": "# Shared trunk: a 28x28x1 image flattened to 784 features.\ni = Input(shape=(28, 28, 1), name='input_image')\nx = Flatten()(i)\n# Head 1: a single sigmoid unit (trained with binary_crossentropy).\no1 = Dense(1, activation='sigmoid', name='output_1')(x)\n# Head 2: 10 classes. Softmax rather than sigmoid, so the outputs form a\n# probability distribution, which is what (sparse_)categorical_crossentropy expects.\no2 = Dense(10, activation='softmax', name='output_2')(x)", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"collapsed": true | |
}, | |
"cell_type": "code", | |
"source": "# Two-headed model: one image input, outputs ordered [output_1, output_2].\nm = Model(\n    inputs=[i],\n    outputs=[o1, o2],\n)", | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def model_as_svg(model):\n    \"\"\"Return an IPython `SVG` object picturing the layer graph of `model`.\n\n    The graph is built by Keras' `model_to_dot` with layer shapes shown and\n    rendered to SVG by the `dot` program.\n    \"\"\"\n    # Imports are local to this helper, so they only run when it is called.\n    from IPython.display import SVG\n    from keras.utils.vis_utils import model_to_dot\n\n    graph = model_to_dot(model, show_shapes=True)\n    svg_markup = graph.create(prog='dot', format='svg')\n    return SVG(svg_markup)", | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Draw the model graph: the flattened input feeds both Dense heads.\nmodel_as_svg(m)", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 5, | |
"data": { | |
"text/plain": "<IPython.core.display.SVG object>", | |
"image/svg+xml": "<svg height=\"221pt\" viewBox=\"0.00 0.00 514.00 221.00\" width=\"514pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 217)\">\n<title>G</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-217 510,-217 510,4 -4,4\" stroke=\"none\"/>\n<!-- 139690491401608 -->\n<g class=\"node\" id=\"node1\"><title>139690491401608</title>\n<polygon fill=\"none\" points=\"94,-166.5 94,-212.5 412,-212.5 412,-166.5 94,-166.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"170\" y=\"-185.8\">input_image: InputLayer</text>\n<polyline fill=\"none\" points=\"246,-166.5 246,-212.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"273.5\" y=\"-197.3\">input:</text>\n<polyline fill=\"none\" points=\"246,-189.5 301,-189.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"273.5\" y=\"-174.3\">output:</text>\n<polyline fill=\"none\" points=\"301,-166.5 301,-212.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"356.5\" y=\"-197.3\">(None, 28, 28, 1)</text>\n<polyline fill=\"none\" points=\"301,-189.5 412,-189.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"356.5\" y=\"-174.3\">(None, 28, 28, 1)</text>\n</g>\n<!-- 139690491401552 -->\n<g class=\"node\" id=\"node2\"><title>139690491401552</title>\n<polygon fill=\"none\" points=\"115,-83.5 115,-129.5 391,-129.5 391,-83.5 115,-83.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"170\" y=\"-102.8\">flatten_1: Flatten</text>\n<polyline fill=\"none\" points=\"225,-83.5 225,-129.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"252.5\" 
y=\"-114.3\">input:</text>\n<polyline fill=\"none\" points=\"225,-106.5 280,-106.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"252.5\" y=\"-91.3\">output:</text>\n<polyline fill=\"none\" points=\"280,-83.5 280,-129.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"335.5\" y=\"-114.3\">(None, 28, 28, 1)</text>\n<polyline fill=\"none\" points=\"280,-106.5 391,-106.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"335.5\" y=\"-91.3\">(None, 784)</text>\n</g>\n<!-- 139690491401608->139690491401552 -->\n<g class=\"edge\" id=\"edge1\"><title>139690491401608->139690491401552</title>\n<path d=\"M253,-166.366C253,-158.152 253,-148.658 253,-139.725\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"256.5,-139.607 253,-129.607 249.5,-139.607 256.5,-139.607\" stroke=\"black\"/>\n</g>\n<!-- 139690491401664 -->\n<g class=\"node\" id=\"node3\"><title>139690491401664</title>\n<polygon fill=\"none\" points=\"0,-0.5 0,-46.5 244,-46.5 244,-0.5 0,-0.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"53\" y=\"-19.8\">output_1: Dense</text>\n<polyline fill=\"none\" points=\"106,-0.5 106,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"133.5\" y=\"-31.3\">input:</text>\n<polyline fill=\"none\" points=\"106,-23.5 161,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"133.5\" y=\"-8.3\">output:</text>\n<polyline fill=\"none\" points=\"161,-0.5 161,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"202.5\" y=\"-31.3\">(None, 784)</text>\n<polyline fill=\"none\" points=\"161,-23.5 244,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"202.5\" y=\"-8.3\">(None, 1)</text>\n</g>\n<!-- 139690491401552->139690491401664 -->\n<g class=\"edge\" id=\"edge2\"><title>139690491401552->139690491401664</title>\n<path d=\"M217.204,-83.3664C201.608,-73.723 183.161,-62.3171 166.678,-52.1252\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"168.099,-48.8891 157.753,-46.6068 164.418,-54.8429 168.099,-48.8891\" stroke=\"black\"/>\n</g>\n<!-- 139690491402392 -->\n<g class=\"node\" id=\"node4\"><title>139690491402392</title>\n<polygon fill=\"none\" points=\"262,-0.5 262,-46.5 506,-46.5 506,-0.5 262,-0.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"315\" y=\"-19.8\">output_2: Dense</text>\n<polyline fill=\"none\" points=\"368,-0.5 368,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"395.5\" y=\"-31.3\">input:</text>\n<polyline fill=\"none\" points=\"368,-23.5 423,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"395.5\" y=\"-8.3\">output:</text>\n<polyline fill=\"none\" points=\"423,-0.5 423,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"464.5\" y=\"-31.3\">(None, 784)</text>\n<polyline fill=\"none\" points=\"423,-23.5 506,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"464.5\" y=\"-8.3\">(None, 10)</text>\n</g>\n<!-- 139690491401552->139690491402392 -->\n<g class=\"edge\" id=\"edge3\"><title>139690491401552->139690491402392</title>\n<path d=\"M288.796,-83.3664C304.392,-73.723 322.839,-62.3171 339.322,-52.1252\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"341.582,-54.8429 348.247,-46.6068 337.901,-48.8891 341.582,-54.8429\" stroke=\"black\"/>\n</g>\n</g>\n</svg>" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Losses are keyed by output-layer name instead of by position.\nm.compile(\n    optimizer='sgd',\n    loss={\n        'output_1': 'binary_crossentropy',\n        'output_2': 'categorical_crossentropy',\n    },\n)", | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Draw the dummy image from a fixed-seed generator so the same input is used on\n# every Restart & Run All. Keras weight initialisation is not seeded here, so\n# predictions and losses can still vary between runs.\nimg = np.random.RandomState(0).random_sample(size=(1, 28, 28, 1))", | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# One forward pass; `p` is a list holding one array per model output.\np = m.predict(img)\np", | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 8, | |
"data": { | |
"text/plain": "[array([[ 0.52987033]], dtype=float32),\n array([[ 0.3056376 , 0.38724217, 0.55166078, 0.33891007, 0.2556529 ,\n 0.2712743 , 0.67876822, 0.61177981, 0.76250398, 0.4633559 ]], dtype=float32)]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# categorical_crossentropy expects one-hot targets, so encode the predicted class\n# of head 2 as a one-hot row. These are the same labels that the\n# sparse_categorical_crossentropy cell passes as integer indices, which makes\n# the two reported losses comparable.\nn_classes = p[1].shape[-1]\ny2_onehot = np.eye(n_classes)[np.argmax(p[1], axis=1)]\nr = m.train_on_batch(img, [p[0], y2_onehot])\nr", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 9, | |
"data": { | |
"text/plain": "[11.032898, 0.69136167, 10.341537]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"collapsed": true | |
}, | |
"cell_type": "code", | |
"source": "# train_on_batch returned the total loss followed by one loss per output head.\nassert len(r) == 3, 'expected 3 loss values, got %d' % len(r)", | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Re-compile the same model, switching head 2 to the sparse loss.\nm.compile(\n    optimizer='sgd',\n    loss={\n        'output_1': 'binary_crossentropy',\n        'output_2': 'sparse_categorical_crossentropy',\n    },\n)", | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# The sparse loss takes class indices, so head 2's target is the argmax of its prediction.\nr = m.train_on_batch(img, [p[0], p[1].argmax(axis=1)])\nr", | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 12, | |
"data": { | |
"text/plain": "[2.4943717, 0.69136167, 1.80301]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# train_on_batch returned the total loss followed by one loss per output head.\nassert len(r) == 3, 'expected 3 loss values, got %d' % len(r)", | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"language_info": { | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"name": "python", | |
"version": "3.5.2", | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python" | |
}, | |
"gist": { | |
"id": "4cb605271fcb16195189305888ee3311", | |
"data": { | |
"description": "Training with categorical_crossentropy vs sparse_categorical_crossentropy", | |
"public": true | |
} | |
}, | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/4cb605271fcb16195189305888ee3311" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment