Created
August 23, 2018 04:59
-
-
Save Orbifold/151a553038f955ebed0c8e2dccda01d6 to your computer and use it in GitHub Desktop.
Using the TensorFlow `LinearRegressor` estimator.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""Fit a noisy line y = 3x + 2 with ``tf.estimator.LinearRegressor``.

Demonstrates feeding NumPy arrays to an Estimator through
``tf.estimator.inputs.numpy_input_fn``. Written against the TensorFlow 1.x
API (``tf.enable_eager_execution``, ``tf.contrib``, ``tf.logging``), which
does not exist in TensorFlow 2.x.
"""
import tensorflow as tf
tf.enable_eager_execution()
assert tf.executing_eagerly()
# NOTE(review): `tfe` is not referenced anywhere in this script.
import tensorflow.contrib.eager as tfe
# Silence TensorFlow's INFO/WARNING logging (too much info otherwise).
tf.logging.set_verbosity(tf.logging.ERROR)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sea

# Apply seaborn's default plot theme.
sea.set()

# `%matplotlib inline` is IPython/Jupyter magic, not Python: it is a
# SyntaxError when this file is run as a script. When using a notebook,
# run that magic in its own cell instead.
# Synthetic training set: the integers 0..N_SAMPLES-1 in random order, with
# targets on the line y = 3x + 2 plus standard-normal noise.
N_SAMPLES = 100
X = np.arange(0, N_SAMPLES)
np.random.shuffle(X)
noise = np.random.normal(size=[N_SAMPLES])
y = 3 * X + 2 + noise

# A linear model over a single numeric feature keyed 'x'.
x_feature = tf.feature_column.numeric_column('x')
lin_reg = tf.estimator.LinearRegressor(feature_columns=[x_feature])

# batch_size=1 and num_epochs=2500 over 100 samples means 250,000
# single-example training steps. X is already in random order (shuffled
# above), so the input_fn is told not to reshuffle.
train_input = tf.estimator.inputs.numpy_input_fn(
    x={"x": X},
    y=y,
    batch_size=1,
    num_epochs=2500,
    shuffle=False,
)
lin_reg.train(train_input)
# Ask the trained model for two out-of-sample points and print each output.
query = np.array([1.9, 1.4], dtype=np.float32)
predict_input = tf.estimator.inputs.numpy_input_fn(
    x={"x": query},
    num_epochs=1,
    shuffle=False,
)
results = lin_reg.predict(predict_input)
for prediction in results:
    print(prediction['predictions'])
# The raw input. X is in shuffled order, so a line plot would zig-zag back
# and forth across the x-range; draw the samples as points instead.
plt.figure()
plt.scatter(X, y, s=8)
plt.title("training data")

# Model predictions on the training inputs. shuffle=False keeps the output
# order aligned with X.
predict_input = tf.estimator.inputs.numpy_input_fn(
    x={"x": X},
    num_epochs=1, shuffle=False)
results = list(lin_reg.predict(predict_input))
predicted = np.array([r["predictions"][0] for r in results])

# Reality vs. prediction on ONE shared y-axis. A twin axis (ax.twinx())
# autoscales independently, so it can make a mis-scaled or offset model
# look like a perfect overlay.
fig, ax = plt.subplots()
sea.regplot(x=X, y=y, ax=ax, label="observed")
sea.regplot(x=X, y=predicted, ax=ax, color='r', label="predicted")
ax.legend()

# Needed to display the figures when run as a script; harmless in Jupyter.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment