Skip to content

Instantly share code, notes, and snippets.

@Orbifold
Created August 23, 2018 04:59
Show Gist options
  • Select an option

  • Save Orbifold/151a553038f955ebed0c8e2dccda01d6 to your computer and use it in GitHub Desktop.

Select an option

Save Orbifold/151a553038f955ebed0c8e2dccda01d6 to your computer and use it in GitHub Desktop.
Using the TensorFlow LinearRegressor estimator with NumPy input data.
#!/usr/bin/env python3
# Demonstrates feeding NumPy arrays to a TensorFlow 1.x canned estimator
# (tf.estimator.LinearRegressor) through tf.estimator.inputs.numpy_input_fn.
import tensorflow as tf
# Switch eager mode on at startup; the assert confirms it took effect.
tf.enable_eager_execution()
assert tf.executing_eagerly()
import tensorflow.contrib.eager as tfe
# too much info otherwise
tf.logging.set_verbosity(tf.logging.ERROR)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sea
# Apply seaborn's default plot styling.
sea.set()
# `%matplotlib inline` is IPython syntax and a SyntaxError in a plain .py
# file, so apply the magic only when running under IPython/Jupyter.
try:
    get_ipython().run_line_magic("matplotlib", "inline")  # noqa: F821
except NameError:
    pass  # not running under IPython; nothing to configure
# Number of synthetic samples.
N_SAMPLES = 100
# Feature values 0 .. N_SAMPLES-1, in random order.
X = np.arange(0, N_SAMPLES)
np.random.shuffle(X)
# Ground truth is the line y = a*x + b with a=3, b=2, plus standard-normal
# noise; size=X.shape keeps the noise length tied to X instead of a
# separately hard-coded count.
y = 3 * X + 2 + np.random.normal(size=X.shape)
# A single numeric input feature, keyed 'x' to match the input_fn dicts.
column = tf.feature_column.numeric_column('x')
lin_reg = tf.estimator.LinearRegressor(feature_columns=[column])
# Training input: the whole data set, one example per batch, repeated for
# 2500 passes (2500 * len(X) steps), without reshuffling between epochs.
train_input = tf.estimator.inputs.numpy_input_fn(
    x={"x": X},
    y=y,
    batch_size=1,
    num_epochs=2500,
    shuffle=False,
)
lin_reg.train(input_fn=train_input)
# Make two predictions for fun: query the trained model at x = 1.9 and 1.4.
# num_epochs=1 with shuffle=False visits each query point once, in order.
predict_input = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array([1.9, 1.4], dtype=np.float32)},
    num_epochs=1, shuffle=False)
# predict() returns an iterable of per-example dicts.
results = lin_reg.predict(predict_input)
# Print result (the loop body must be indented; it was not).
for value in results:
    print(value['predictions'])
# The input looks like this. X was shuffled, so a line plot would join the
# points in random order and draw a zig-zag; use a scatter plot instead.
plt.scatter(X, y)
# Reality vs. prediction: run the model over every training input.
predict_input = tf.estimator.inputs.numpy_input_fn(
    x={"x": X},
    num_epochs=1, shuffle=False)
results = list(lin_reg.predict(predict_input))
# Each result's 'predictions' entry is indexed with [0] to get the scalar.
predictions = np.array([r["predictions"][0] for r in results])
fig, ax = plt.subplots()
# Draw both series on the SAME axes. A twinx() secondary axis would give them
# independent y-scales and make the comparison misleading. Keyword x/y are
# required by newer seaborn releases.
sea.regplot(x=X, y=y, ax=ax, label="data")
sea.regplot(x=X, y=predictions, ax=ax, color='r', label="prediction")
ax.legend()
# Needed to display figures when run as a plain script; harmless in Jupyter.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment