Skip to content

Instantly share code, notes, and snippets.

@sj82516
Created July 18, 2017 08:57
Show Gist options
  • Select an option

  • Save sj82516/111d13aef679a3d11aa700dd7b89e9ef to your computer and use it in GitHub Desktop.

Select an option

Save sj82516/111d13aef679a3d11aa700dd7b89e9ef to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
# Generate 100 samples from the linear relationship y = 3.21 * x + 2,
# adding uniform noise in [-10, 10) to every target value.
x_data = np.linspace(-10.0, 10.0, num=100)
y_data = 3.21 * x_data + 2 + np.random.uniform(-10.0, 10.0, 100)

# Build the TensorFlow 1.x (graph-mode) computation graph.
# Both inputs are 1-D float32 placeholders; the float64 numpy arrays fed
# into them below are converted to float32 when fed.
x = tf.placeholder(tf.float32, shape=(None,), name="x")
y = tf.placeholder(tf.float32, shape=(None,), name="y")
# The slope starts at 3.2 and the intercept from a standard-normal draw.
# The dtype is pinned to float32 so `W * x + b` never mixes float32 and
# float64 tensors, regardless of the scalar type numpy returns.
W = tf.Variable([3.2], dtype=tf.float32)
b = tf.Variable(np.random.normal(0, 1), dtype=tf.float32)
y_pred = W * x + b
# Sum (not mean) of squared errors over all samples: the gradient scales
# with the number of samples, which is why the learning rate is so small.
loss = tf.reduce_sum(tf.pow(y_pred - y, 2))
train = tf.train.GradientDescentOptimizer(0.00005).minimize(loss)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    # Full-batch gradient descent: every step uses all 100 samples.
    for _ in range(1000):
        sess.run(train, feed_dict={x: x_data, y: y_data})
    # Only `x` needs feeding here: `y_pred`, `W` and `b` do not depend on
    # the `y` placeholder.
    y_pred_batch, curr_W, curr_b = sess.run([y_pred, W, b], feed_dict={x: x_data})
    print(y_pred_batch, curr_W, curr_b)

# Overlay the noisy data and the fitted predictions.
plt.figure(1)
plt.scatter(x_data, y_data)
plt.scatter(x_data, y_pred_batch)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment