Last active
August 23, 2018 02:50
-
-
Save jee/55e3344a61a84a76426e9d446a56e0fd to your computer and use it in GitHub Desktop.
Quick linear-regression example using TensorFlow and the Adam optimizer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# TensorFlow (graph-mode API), NumPy for data, matplotlib for plotting\n", | |
"import tensorflow as tf\n", | |
"import numpy\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"# rng is NumPy's random module itself (not a Generator object); later cells call rng.randn()\n", | |
"from numpy import random as rng" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Parameters\n", | |
"learning_rate = 0.01   # step size passed to the Adam optimizer\n", | |
"training_epochs = 1000 # full passes over the training data\n", | |
"display_step = 50      # log progress every display_step epochs\n", | |
"random_seed = 0        # seeds NumPy's global RNG so the random initial W and b are reproducible\n", | |
"rng.seed(random_seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Training Data\n", | |
"# train_Y holds the observed values; train_X is each observation's 0-based position\n", | |
"# (0 .. n_samples-1), derived from train_Y so the two arrays can never differ in length.\n", | |
"train_Y = numpy.asarray([356.09, 350.82, 353.14, 358.56, 352.67, 349.44, 343.1, 339.48, 330.53, 326.72, 320.63, 320.51, 326.06, 326.43, 300.66, 300.69, 303.25, 303.26, 304.1, 300.38, 305.11, 307.96, 311.63, 307.0, 314.72, 319.91, 310.13, 313.47, 314.63, 313.71, 313.42, 316.96, 309.62, 307.62, 307.69, 304.44, 304.0, 306.69, 314.84, 314.12, 321.38, 335.74, 340.36, 342.17, 339.83, 342.15, 335.9, 329.07, 330.48, 327.87, 320.26, 314.22, 312.68, 313.21, 316.55, 320.4, 312.12, 314.62, 326.26, 333.1, 333.0, 339.03, 337.04, 339.9, 344.38, 348.02, 346.59, 353.51, 355.75, 349.13, 342.8, 339.86, 344.56, 345.22, 350.69, 354.14, 346.23, 338.74, 329.86, 340.83, 331.61, 307.87, 312.16, 318.35, 322.35, 328.26, 337.38, 336.17, 336.43, 341.1, 351.05, 355.68, 355.0, 348.73, 339.37, 329.1, 333.52, 331.7, 327.12, 329.79, 325.43, 336.86, 340.69, 331.87, 326.98, 323.24, 315.21, 312.5, 316.31, 313.5, 306.03, 299.48, 290.73, 260.39, 259.58, 252.46, 263.92, 270.19, 297.23, 302.39, 299.36, 300.39, 304.32, 298.81, 299.96, 294.34, 287.34, 294.2, 294.78, 294.87, 286.98, 282.77, 281.21, 281.14, 289.15, 295.62, 297.02, 302.32, 281.63, 288.19, 300.56, 303.38, 303.53, 308.55, 303.98, 298.28, 283.73, 285.19, 286.58, 279.32, 286.39, 280.71, 276.95, 278.0, 277.63, 281.32, 288.3, 286.65, 287.89, 296.27, 292.27, 309.83, 321.79, 320.81, 328.58, 346.49, 343.5, 352.68, 357.96, 364.12, 358.13, 358.19, 356.24, 342.13, 332.99, 334.67, 345.14, 351.56, 348.13, 347.31, 321.09, 305.31, 307.03, 313.26, 323.44, 318.5, 318.0, 314.42, 310.71, 316.62, 320.88, 318.77, 317.47, 299.18, 300.13, 302.06, 307.17, 301.52, 291.12, 298.0])\n", | |
"train_X = numpy.arange(train_Y.shape[0])\n", | |
"n_samples = train_X.shape[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# tf Graph Input: float32 placeholders with no fixed shape, so a single sample\n", | |
"# or a whole array can be fed\n", | |
"X, Y = tf.placeholder(tf.float32), tf.placeholder(tf.float32)\n", | |
"\n", | |
"# Set model weights: each starts from one standard-normal draw (weight drawn before bias)\n", | |
"initial_weight = rng.randn()\n", | |
"initial_bias = rng.randn()\n", | |
"W = tf.Variable(initial_weight, name=\"weight\")\n", | |
"b = tf.Variable(initial_bias, name=\"bias\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Construct a linear model, pred = X*W + b, via TensorFlow's overloaded * and + operators\n", | |
"pred = X * W + b" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Cost: half the mean squared error over however many examples are fed.\n", | |
"# reduce_mean divides by the fed batch size, so the value has the same meaning for a\n", | |
"# single (x, y) pair (as fed during training) and for the full training set.\n", | |
"cost = tf.reduce_mean(tf.square(pred - Y)) / 2\n", | |
"\n", | |
"# Adam (\"Adaptive Moment Estimation\") is one of the current default optimizers in deep\n", | |
"# learning development and can be seen as combining ideas from AdaGrad and RMSProp.\n", | |
"# tf.train.GradientDescentOptimizer(learning_rate) is a drop-in plain gradient-descent alternative.\n", | |
"optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# A single op that initializes every global variable in the graph built so far (including W and b)\n", | |
"init = tf.variables_initializer(tf.global_variables())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0050 cost= 11212.208984375 W= 1.624199 b= 46.959618\n", | |
"Epoch: 0100 cost= 8066.003417969 W= 1.3073506 b= 93.06213\n", | |
"Epoch: 0150 cost= 5434.843750000 W= 1.0204656 b= 137.98753\n", | |
"Epoch: 0200 cost= 3337.601806641 W= 0.76821214 b= 181.22789\n", | |
"Epoch: 0250 cost= 1820.522827148 W= 0.5456723 b= 222.08707\n", | |
"Epoch: 0300 cost= 883.478271484 W= 0.34873167 b= 259.24255\n", | |
"Epoch: 0350 cost= 439.807617188 W= 0.18674372 b= 289.90445\n", | |
"Epoch: 0400 cost= 309.644042969 W= 0.077910334 b= 310.49673\n", | |
"Epoch: 0450 cost= 295.321746826 W= 0.021207927 b= 321.24008\n", | |
"Epoch: 0500 cost= 300.897033691 W= -0.004233281 b= 326.06818\n", | |
"Epoch: 0550 cost= 305.567230225 W= -0.015110233 b= 328.13446\n", | |
"Epoch: 0600 cost= 307.959197998 W= -0.01971694 b= 329.00995\n", | |
"Epoch: 0650 cost= 309.042053223 W= -0.02165947 b= 329.3792\n", | |
"Epoch: 0700 cost= 309.503479004 W= -0.022464473 b= 329.53223\n", | |
"Epoch: 0750 cost= 309.719940186 W= -0.022837063 b= 329.6031\n", | |
"Epoch: 0800 cost= 309.797424316 W= -0.022970244 b= 329.6284\n", | |
"Epoch: 0850 cost= 309.829315186 W= -0.023024794 b= 329.63876\n", | |
"Epoch: 0900 cost= 309.843109131 W= -0.023048406 b= 329.64325\n", | |
"Epoch: 0950 cost= 309.853973389 W= -0.023067273 b= 329.64682\n", | |
"Epoch: 1000 cost= 309.865814209 W= -0.023087274 b= 329.65063\n", | |
"Optimization Finished!\n", | |
"Training cost= 309.8658 W= -0.023087274 b= 329.65063 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Start training\n", | |
"with tf.Session() as sess:\n", | |
"    sess.run(init)\n", | |
"\n", | |
"    # Fit all training data: one optimizer step per (x, y) pair, i.e. batch size 1\n", | |
"    for epoch in range(training_epochs):\n", | |
"        for (x, y) in zip(train_X, train_Y):\n", | |
"            sess.run(optimizer, feed_dict={X: x, Y: y})\n", | |
"\n", | |
"        # Display logs every display_step epochs, with the cost measured on the full data set\n", | |
"        if (epoch+1) % display_step == 0:\n", | |
"            c, w_now, b_now = sess.run([cost, W, b], feed_dict={X: train_X, Y: train_Y})\n", | |
"            # A single pre-formatted string prints the same under Python 2 and Python 3\n", | |
"            print('Epoch: %04d cost= %.9f W= %s b= %s' % (epoch+1, c, w_now, b_now))\n", | |
"\n", | |
"    print('Optimization Finished!')\n", | |
"    # Fetch the fitted parameters as concrete values so later cells can use them after the session closes\n", | |
"    training_cost, model_W, model_b = sess.run([cost, W, b], feed_dict={X: train_X, Y: train_Y})\n", | |
"    print('Training cost= %s W= %s b= %s \\n' % (training_cost, model_W, model_b))\n", | |
"\n", | |
"    # Graphic display: original points and the fitted line\n", | |
"    plt.plot(train_X, train_Y, 'ro', label='Original data')\n", | |
"    plt.plot(train_X, model_W * train_X + model_b, label='Fitted line')\n", | |
"    plt.xlabel('x')\n", | |
"    plt.ylabel('y')\n", | |
"    plt.legend()\n", | |
"    plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(\"==> Tomorrow's prediction is:\", 325.033184)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Predict the step right after the training data: train_X runs 0..n_samples-1,\n", | |
"# so \"tomorrow\" is x = n_samples. The parameters come from the training cell's\n", | |
"# session rather than hand-copied constants, so they always match the last run.\n", | |
"tomorrow = model_W * n_samples + model_b\n", | |
"\n", | |
"print(\"==> Tomorrow's prediction is: %s\" % tomorrow)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment