Last active
March 14, 2017 18:40
-
-
Save trcook/644c9046799719275f2fda1950d664c2 to your computer and use it in GitHub Desktop.
Build a TensorFlow Learn (`tf.contrib.learn`) estimator and run it, generating plots along the way. The input function queues data as needed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow.contrib.learn as learn | |
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib | |
import tensorflow as tf | |
import tensorflow.contrib.slim as sm | |
import numpy as np | |
import functools | |
def build_toy_dataset(N=40, noise_std=0.1):
    """Build a 1-D toy regression set: noisy cos(x) sampled on two disjoint intervals.

    ``N`` points are split between ``[0, 2]`` and ``[6, 8]``; when ``N`` is odd
    the second interval receives the extra point.  Targets are ``cos(x)`` plus
    Gaussian noise with standard deviation ``noise_std``.  The inputs are then
    rescaled with ``(x - 4) / 4`` so they lie in ``[-1, 1]``.

    Args:
        N: total number of samples.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        ``(x, y)``: two float32 arrays of shape ``(N, 1)``.
    """
    D = 1  # input dimensionality
    # Integer split: recent NumPy rejects a float ``num`` (``N / 2`` under
    # Python 3), and ``N - n_first`` keeps the total at exactly N for odd N.
    n_first = N // 2
    x = np.concatenate([np.linspace(0, 2, num=n_first),
                        np.linspace(6, 8, num=N - n_first)])
    y = np.cos(x) + np.random.normal(0, noise_std, size=N)
    # Rescale only after the targets were computed from the raw inputs.
    x = (x - 4.0) / 4.0
    x = x.astype(np.float32).reshape((N, D))
    y = y.astype(np.float32).reshape((N, 1))
    return x, y
# Module-level toy data (40 points, noise std 0.1) used by the training script.
x, y = build_toy_dataset(N=40, noise_std=0.1)
def train_fn(x, y, queue_all=True, batch_size=1):
    """Return an ``input_fn`` that feeds ``x``/``y`` through a TF batching queue.

    Args:
        x: input array.
        y: target array.
        queue_all: forwarded to ``tf.train.batch`` as ``enqueue_many``.
            - When True, the leading dimension of ``x``/``y`` is split into
              separate queue elements, so each batch holds ``batch_size``
              samples.
            - When False, the whole arrays are enqueued as a single element,
              so with ``batch_size=1`` one batch contains the full dataset.
        batch_size: number of queue elements per dequeued batch. The default
            of 1 preserves the original behaviour.

    Returns:
        A zero-argument ``input_fn`` returning ``({'input': features}, targets)``,
        with both tensors reshaped to ``[-1, 1]``.
    """
    def input_fn():
        data = tf.train.batch({'input': x, 'target': y},
                              batch_size=batch_size,
                              enqueue_many=queue_all)
        # Collapse the batch (and, for queue_all=False, the dataset) axes into
        # a single column of samples.
        features = tf.reshape(data['input'], [-1, 1])
        targets = tf.reshape(data['target'], [-1, 1])
        return {'input': features}, targets
    return input_fn
def model_fn(features, targets, mode, params):
    """Fully connected ReLU6 regressor for ``tf.contrib.learn.Estimator``.

    Args:
        features: dict whose ``'input'`` entry is the input tensor.
        targets: target tensor. It is only used outside INFER mode.
        mode: the ``ModeKeys`` value supplied by the Estimator.
        params: optional dict of hyper-parameters:
            ``learning_rate`` (default 0.001), ``num_layers`` (default 3)
            and ``hidden_units`` (default 100).

    Returns:
        ``learn.ModelFnOps``. The loss is built only outside INFER mode, and
        the Adam train op only in TRAIN mode.
    """
    params = params or {}  # Estimator may pass params=None
    net = sm.repeat(features['input'],
                    params.get('num_layers', 3),
                    sm.fully_connected,
                    params.get('hidden_units', 100),
                    activation_fn=tf.nn.relu6)
    # Linear output layer: one regression value per sample.
    prediction = sm.fully_connected(net, 1, activation_fn=None)

    loss = None
    train_op = None
    # Prediction needs neither a loss nor optimizer ops; building them there
    # would require meaningful targets that inference does not have.
    if mode != model_fn_lib.ModeKeys.INFER and targets is not None:
        loss = tf.losses.mean_squared_error(targets, prediction)
        if mode == model_fn_lib.ModeKeys.TRAIN:
            train_op = tf.contrib.layers.optimize_loss(
                loss=loss,
                global_step=tf.contrib.framework.get_global_step(),
                learning_rate=params.get('learning_rate', .001),
                optimizer="Adam")

    return learn.ModelFnOps(
        mode=mode,
        predictions=prediction,
        loss=loss,
        train_op=train_op,
        eval_metric_ops={})
# Train in short rounds and overlay each round's predictions on the data.
# NOTE: ``plt`` was previously used without being imported (a NameError
# outside a pylab/IPython session), so it is imported explicitly here.
import matplotlib.pyplot as plt

esto = learn.Estimator(model_fn, params={"learning_rate": .001})
plt.figure()
# Draw the noisy targets once instead of re-plotting them every round.
plt.plot(x, y, label='data')
for i in range(3):
    esto.fit(input_fn=train_fn(x, y), steps=10)
    # queue_all=False enqueues the whole dataset as one element, so a single
    # predict batch covers every point in ``x``.
    jj = esto.predict(input_fn=train_fn(x, y, queue_all=False),
                      as_iterable=False)
    plt.plot(x, np.reshape(jj, (-1, 1)), label='round %d' % (i + 1))
plt.legend()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Written for TensorFlow 1.0.