Last active
August 4, 2016 05:16
-
-
Save xccds/543780a9457faf1f648a37e1f950356f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(0, array([ 0.61638772], dtype=float32), array([ 0.02258076], dtype=float32), 0.02007756)\n", | |
"(20, array([ 0.25472665], dtype=float32), array([ 0.21643446], dtype=float32), 0.001806676)\n", | |
"(40, array([ 0.14641654], dtype=float32), array([ 0.27493113], dtype=float32), 0.0001625907)\n", | |
"(60, array([ 0.11392454], dtype=float32), array([ 0.29247957], dtype=float32), 1.4632288e-05)\n", | |
"(80, array([ 0.10417723], dtype=float32), array([ 0.29774395], dtype=float32), 1.3168243e-06)\n", | |
"(100, array([ 0.10125311], dtype=float32), array([ 0.29932323], dtype=float32), 1.1850291e-07)\n", | |
"(120, array([ 0.10037591], dtype=float32), array([ 0.299797], dtype=float32), 1.0663627e-08)\n", | |
"(140, array([ 0.10011276], dtype=float32), array([ 0.2999391], dtype=float32), 9.5966568e-10)\n", | |
"(160, array([ 0.10003382], dtype=float32), array([ 0.29998174], dtype=float32), 8.6334516e-11)\n", | |
"(180, array([ 0.10001013], dtype=float32), array([ 0.29999453], dtype=float32), 7.7622622e-12)\n", | |
"(200, array([ 0.10000303], dtype=float32), array([ 0.29999837], dtype=float32), 6.9795726e-13)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Linear regression example: recover W and b from noise-free samples of y = x * 0.1 + 0.3\n", | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"\n", | |
"# Fix the NumPy and TensorFlow (graph-level) seeds so a fresh Restart & Run All reproduces the trace\n", | |
"SEED = 0\n", | |
"np.random.seed(SEED)\n", | |
"tf.set_random_seed(SEED)\n", | |
"\n", | |
"# Build 100 samples from the true relationship y = x * 0.1 + 0.3\n", | |
"x_data = np.random.rand(100).astype(np.float32)\n", | |
"y_data = x_data * 0.1 + 0.3\n", | |
"\n", | |
"# Slope W and intercept b are the parameters to estimate in y_data = W * x_data + b;\n", | |
"# they are declared as variables and given initial values first.\n", | |
"W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))\n", | |
"b = tf.Variable(tf.zeros([1]))\n", | |
"# Model relating the parameters to the input\n", | |
"y = W * x_data + b\n", | |
"\n", | |
"# Mean squared error is the loss, i.e. the objective the optimizer minimizes;\n", | |
"# reduce_mean averages every element of the tensor down to a single scalar.\n", | |
"loss = tf.reduce_mean(tf.square(y - y_data))\n", | |
"# Gradient-descent optimizer with learning rate 0.5\n", | |
"optimizer = tf.train.GradientDescentOptimizer(0.5)\n", | |
"# Op that performs one minimization step on the loss\n", | |
"train = optimizer.minimize(loss)\n", | |
"\n", | |
"# Op that initializes all variables created so far\n", | |
"init = tf.initialize_all_variables()\n", | |
"\n", | |
"# Start a session and run the initializer\n", | |
"sess = tf.Session()\n", | |
"sess.run(init)\n", | |
"\n", | |
"# One gradient step per iteration, logging every 20 steps.\n", | |
"# range (not xrange) so the cell also runs under Python 3.\n", | |
"for step in range(201):\n", | |
"    sess.run(train)\n", | |
"    if step % 20 == 0:\n", | |
"        # Fetch W, b and the loss in a single run call instead of three\n", | |
"        w_val, b_val, loss_val = sess.run([W, b, loss])\n", | |
"        print(step, w_val, b_val, loss_val)\n", | |
"\n", | |
"# Learns best fit is W: [0.1], b: [0.3]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n", | |
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", | |
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", | |
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n" | |
] | |
} | |
], | |
"source": [ | |
"# Handwritten-digit recognition example using multinomial (softmax) logistic regression\n", | |
"# Load the MNIST dataset through TensorFlow's tutorial helper, with labels one-hot encoded\n", | |
"from tensorflow.examples.tutorials.mnist import input_data\n", | |
"MNIST_DIR = 'MNIST_data'\n", | |
"mnist = input_data.read_data_sets(MNIST_DIR, one_hot=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"# InteractiveSession installs itself as the default session, so later cells can call\n", | |
"# Operation.run() / Tensor.eval() without passing a session explicitly\n", | |
"sess = tf.InteractiveSession()\n", | |
"N_FEATURES = 784  # input values per flattened image\n", | |
"N_CLASSES = 10    # one output per digit class\n", | |
"# Placeholders for a batch of inputs X and one-hot labels Y; None leaves the batch size open\n", | |
"x = tf.placeholder(tf.float32, [None, N_FEATURES])\n", | |
"y_ = tf.placeholder(tf.float32, [None, N_CLASSES])\n", | |
"# Trainable parameters, both initialized to zero\n", | |
"W = tf.Variable(tf.zeros([N_FEATURES, N_CLASSES]))\n", | |
"b = tf.Variable(tf.zeros([N_CLASSES]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Initialize every variable currently in the default graph (including W and b above)\n", | |
"sess.run(tf.initialize_all_variables())\n", | |
"# Model: the logits xW + b are mapped to class probabilities by softmax\n", | |
"logits = tf.matmul(x, W) + b\n", | |
"y = tf.nn.softmax(logits)\n", | |
"# Cross-entropy loss, computed from the logits with the fused op for numerical stability:\n", | |
"# the naive -sum(y_ * log(softmax)) yields inf/NaN as soon as a predicted probability\n", | |
"# underflows to 0 (log(0) = -inf). Keyword arguments work across old and new TF signatures.\n", | |
"cross_entropy = tf.reduce_mean(\n", | |
"    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Optimizer: plain gradient descent on the cross-entropy with learning rate 0.5\n", | |
"LEARNING_RATE = 0.5\n", | |
"mnist_optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)\n", | |
"train_step = mnist_optimizer.minimize(cross_entropy)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Train: every step feeds one mini-batch of 50 training examples to the optimizer\n", | |
"N_STEPS = 1000\n", | |
"BATCH_SIZE = 50\n", | |
"for step in range(N_STEPS):\n", | |
"    batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)\n", | |
"    train_step.run(feed_dict={x: batch_xs, y_: batch_ys})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.9171\n" | |
] | |
} | |
], | |
"source": [ | |
"# Test-set accuracy: share of test images whose highest-scoring predicted class\n", | |
"# equals the class marked in the one-hot label\n", | |
"predicted_class = tf.argmax(y, 1)\n", | |
"true_class = tf.argmax(y_, 1)\n", | |
"correct_prediction = tf.equal(predicted_class, true_class)\n", | |
"accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", | |
"test_feed = {x: mnist.test.images, y_: mnist.test.labels}\n", | |
"print(accuracy.eval(feed_dict=test_feed))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment