Skip to content

Instantly share code, notes, and snippets.

@jskDr
Created November 22, 2016 02:57
Show Gist options
  • Save jskDr/289fb814b9b67c37d3486628ef26806b to your computer and use it in GitHub Desktop.
Save jskDr/289fb814b9b67c37d3486628ef26806b to your computer and use it in GitHub Desktop.
Keras - User Specific Loss Function Using Backend
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
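"# Keras - User-Defined Loss Function Using the Backend\n",
"- IPython notebook version of the example in https://stackoverflow.com/questions/38684768/customize-keras-loss-function-in-a-way-that-the-y-true-will-depend-on-y-pred\n",
"- Defines a custom sum-of-squared-errors loss function, then compares `model.predict` and `model.evaluate` against the same quantities computed directly with NumPy."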
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using gpu device 0: GeForce GTX 1070 (CNMeM is disabled, cuDNN 5005)\n",
"Using Theano backend.\n"
]
}
],
"source": [
"import theano\n",
"from keras import backend as K\n",
"from keras.layers import Dense\n",
"from keras.models import Sequential"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def customized_loss(y_true, y_pred):\n",
"    \"\"\"Sum of squared errors between y_true and y_pred.\n",
"\n",
"    Built only from Keras backend ops (K.square, K.sum) so it operates on the\n",
"    symbolic tensors Keras passes to a loss; np.square is a NumPy op and is\n",
"    not guaranteed to accept backend tensors.\n",
"\n",
"    K.sum is called without an axis, so the result is the total over the whole\n",
"    batch, which is what np.sum(np.square(y - output_calc)) reproduces below.\n",
"    \"\"\"\n",
"    # Variant from the linked question that contributes 0 wherever y_true == -1:\n",
"    # loss = K.switch(K.equal(y_true, -1), 0, K.square(y_true - y_pred))\n",
"    loss = K.square(y_true - y_pred)\n",
"    return K.sum(loss)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model = Sequential()\n",
"model.add(Dense(3, input_shape=(4,)))\n",
"model.compile(optimizer='sgd', loss=customized_loss)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0.52077309, 0.16107648, 0.28060555, 0.68680049]]),\n",
" array([[ 1, -1, 0]]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.random.rand(1, 4)\n",
"y = np.array([1, -1, 0]).reshape(1, 3)\n",
"x, y"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0.29694095, -0.39778247, 0.13337782],\n",
" [-0.82495385, -0.65100169, 0.43441007],\n",
" [-0.49311569, -0.50998557, 0.22923476],\n",
" [-0.20695877, -0.85389435, 0.4840104 ]], dtype=float32),\n",
" array([ 0., 0., 0.], dtype=float32))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dense(3) on 4 inputs: kernel W has shape (4, 3), bias b has shape (3,)\n",
"[W, b] = model.get_weights()\n",
"(W, b)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-0.25875217 -1.04157531 0.53617591]] [[-0.25875219 -1.04157531 0.53617595]] [[ 2.01450578e-08 -1.82999171e-09 -4.01655094e-08]]\n"
]
}
],
"source": [
"output = model.predict(x)\n",
"# Reference forward pass of the single Dense layer, computed in NumPy.\n",
"output_calc = np.dot(x, W) + b\n",
"err = output - output_calc\n",
"# W and b are float32 (see the get_weights output) while x is float64, so the\n",
"# two results agree only to float32 precision; assert with a matching tolerance.\n",
"np.testing.assert_allclose(output, output_calc, rtol=1e-5, atol=1e-6)\n",
"print(output, output_calc, err)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s\n",
"1.87367010117 1.87367022163\n"
]
}
],
"source": [
"# Loss reported by Keras vs. the same sum of squared errors computed in NumPy.\n",
"# verbose=0 suppresses the progress bar so only the compared values are shown.\n",
"loss_keras = model.evaluate(x, y, verbose=0)\n",
"loss_calc = np.sum(np.square(y - output_calc))\n",
"np.testing.assert_allclose(loss_keras, loss_calc, rtol=1e-5)\n",
"print(loss_keras, loss_calc)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"toc": {
"nav_menu": {
"height": "45px",
"width": "252px"
},
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 4,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment