Skip to content

Instantly share code, notes, and snippets.

@trcook
Last active March 14, 2017 18:40
Show Gist options
  • Save trcook/644c9046799719275f2fda1950d664c2 to your computer and use it in GitHub Desktop.
Save trcook/644c9046799719275f2fda1950d664c2 to your computer and use it in GitHub Desktop.
Build a TensorFlow `tf.contrib.learn` Estimator and train it, generating plots along the way. The input function queues data as needed.
import tensorflow.contrib.learn as learn
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
import tensorflow as tf
import tensorflow.contrib.slim as sm
import numpy as np
import functools
def build_toy_dataset(N=40, noise_std=0.1):
    """Return a 1-D toy regression dataset: noisy cos(x) on two disjoint intervals.

    The first ``N // 2`` points are spaced evenly on [0, 2] and the remaining
    points on [6, 8], leaving a gap with no data in between. Targets are
    ``cos(x)`` plus Gaussian noise. Inputs are then rescaled with
    ``(x - 4) / 4``, so they land in [-1, -0.5] and [0.5, 1].

    Args:
        N: total number of points. Odd values are allowed; the extra point
            goes to the second interval.
        noise_std: standard deviation of the additive Gaussian noise on y
            (drawn from the global ``np.random`` state).

    Returns:
        (x, y): two float32 arrays, each of shape (N, 1).
    """
    D = 1
    # np.linspace needs an integer `num`; `N / 2` is a float under Python 3.
    n_first = N // 2
    # Give the remainder to the second interval so the total is exactly N.
    n_second = N - n_first
    x = np.concatenate([np.linspace(0, 2, num=n_first),
                        np.linspace(6, 8, num=n_second)])
    y = np.cos(x) + np.random.normal(0, noise_std, size=N)
    x = (x - 4.0) / 4.0
    x = x.astype(np.float32).reshape((N, D))
    y = y.astype(np.float32).reshape((N, 1))
    return x, y
# Module-level training data used by the fit/predict/plot loop further down.
x, y = build_toy_dataset(N=40, noise_std=0.1)
def train_fn(x, y, queue_all=True):
    """Return an Estimator ``input_fn`` that serves (x, y) through ``tf.train.batch``.

    Args:
        x: input data; after batching it is reshaped to a ``[-1, 1]`` column,
            i.e. one feature per example.
        y: targets aligned with ``x``; likewise reshaped to ``[-1, 1]``.
        queue_all: forwarded to ``tf.train.batch`` as ``enqueue_many``.
            When True, the leading dimension of x/y is split into separate
            examples and each dequeue returns one example (batch_size=1).
            When False, the full arrays are enqueued as a single element, so
            one dequeue yields the whole dataset.

    Returns:
        A zero-argument callable producing ``({'input': features}, targets)``.
    """
    def input_fn():
        batched = tf.train.batch(
            {'input': x, 'target': y},
            batch_size=1,
            enqueue_many=queue_all,
        )
        # Flatten whatever batch shape came out of the queue into a column.
        features = tf.reshape(batched['input'], [-1, 1])
        labels = tf.reshape(batched['target'], [-1, 1])
        return {'input': features}, labels

    return input_fn
def model_fn(features, targets, mode, params):
    """Small MLP regressor used as the ``tf.contrib.learn`` Estimator model_fn.

    Args:
        features: dict whose ``'input'`` entry is the input tensor
            (assumed to be shaped [batch, 1]; verify against the input_fn).
        targets: regression targets tensor, or None when no labels are given.
        mode: a ``learn.ModeKeys`` value (TRAIN, EVAL or INFER).
        params: optional dict of hyper-parameters. Only ``'learning_rate'``
            is read (default 0.001). None is treated as an empty dict.

    Returns:
        ``learn.ModelFnOps`` with the predictions. The mean-squared-error loss
        is attached when targets are available outside INFER mode. The Adam
        train op is built only in TRAIN mode.
    """
    # Estimator passes params=None when constructed without them.
    params = params or {}

    # Three fully connected hidden layers of 100 ReLU6 units each,
    # followed by a linear single-unit output layer.
    net = sm.repeat(features['input'], 3, sm.fully_connected, 100,
                    activation_fn=tf.nn.relu6)
    prediction = sm.fully_connected(net, 1, activation_fn=None)

    loss = None
    train_op = None
    # Prediction needs neither a loss nor an optimizer; skip building them.
    if targets is not None and mode != learn.ModeKeys.INFER:
        loss = tf.losses.mean_squared_error(targets, prediction)
        if mode == learn.ModeKeys.TRAIN:
            train_op = tf.contrib.layers.optimize_loss(
                loss=loss,
                global_step=tf.contrib.framework.get_global_step(),
                learning_rate=params.get('learning_rate', .001),
                optimizer="Adam")

    return learn.ModelFnOps(
        mode=mode,
        predictions=prediction,
        loss=loss,
        train_op=train_op,
        eval_metric_ops={})
# Build the Estimator, then alternate short training runs with full-dataset
# predictions, plotting the current fit after every round.
import matplotlib.pyplot as plt  # `plt` was used below but never imported

n_rounds = 3
steps_per_round = 10

esto = learn.Estimator(model_fn, params={"learning_rate": .001})
for _ in range(n_rounds):
    esto.fit(input_fn=train_fn(x, y), steps=steps_per_round)
    # queue_all=False enqueues the whole arrays as one queue element, so a
    # single dequeue (and therefore a single predict batch) covers every row.
    jj = esto.predict(input_fn=train_fn(x, y, queue_all=False),
                      as_iterable=False)
    plt.figure()
    # Markers rather than a line: the inputs have a gap between the two
    # clusters, and a line would imply data where there is none.
    plt.plot(x, y, '.', label='data')
    plt.plot(x, np.asarray(jj).reshape((-1, 1)), label='prediction')
    plt.legend()
    plt.show()
@trcook
Copy link
Author

trcook commented Mar 14, 2017

Written for TensorFlow 1.0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment