Last active
March 17, 2017 15:39
-
-
Save trcook/82696fb193e11a1a4c45b71b4b8b1131 to your computer and use it in GitHub Desktop.
Use a TensorFlow (tf.contrib.learn) Estimator with Monte Carlo dropout to generate uncertainty estimates.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow.contrib.learn as learn | |
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib | |
import tensorflow as tf | |
import tensorflow.contrib.slim as sm | |
import numpy as np | |
import functools | |
import matplotlib.pyplot as plt | |
def build_toy_dataset(N=40, noise_std=0.1, low=-5, high=5, seed=None):
    """Sample a noisy cosine curve on an evenly spaced 1-D grid.

    Args:
        N: number of points.
        noise_std: standard deviation of the additive Gaussian noise.
        low, high: inclusive endpoints of the x grid (defaults keep the
            original range [-5, 5]).
        seed: optional integer seed for reproducible noise. When None the
            global ``np.random`` state is used, exactly as before.

    Returns:
        (x, y): arrays of shape (N,), where y = cos(x) + noise.
    """
    x = np.linspace(low, high, num=N)
    # Default to the global RNG so an outer np.random.seed() still controls
    # the noise; a private RandomState is used only when a seed is given.
    rng = np.random if seed is None else np.random.RandomState(seed)
    y = np.cos(x) + rng.normal(0, noise_std, size=N)
    return x, y
def plot_uncertainty(mods, repeats, x, y, interval=95.0, show=True):
    """Plot the mean of repeated stochastic predictions and a percentile band.

    Args:
        mods: stacked predictions from ``repeats`` passes over ``x``, laid out
            pass-major (all of pass 0, then all of pass 1, ...), so that
            ``reshape([repeats, -1])`` gives one row per pass.
        repeats: number of passes stacked in ``mods``.
        x: 1-D grid the predictions were made on.
        y: observed targets drawn on top (may contain NaN where there is no
            observation).
        interval: width of the central band in percent. The default 95
            spans the 2.5th to 97.5th percentile, as before.
        show: call ``plt.show()`` after drawing.

    Raises:
        ValueError: if ``interval`` is outside (0, 100], or ``mods`` does not
            hold exactly ``repeats * len(x)`` predictions.
    """
    if not 0 < interval <= 100:
        raise ValueError('interval must be in (0, 100], got %r' % (interval,))
    x = np.asarray(x)
    mods = np.asarray(mods)
    if repeats < 1 or mods.size != repeats * len(x):
        raise ValueError(
            'expected %d predictions (repeats=%r x %d points), got %d'
            % (repeats * len(x), repeats, len(x), mods.size))
    # One row per stochastic pass, one column per x value.
    vdata = mods.reshape([repeats, -1])
    tail = (100.0 - interval) / 2.0
    mu = vdata.mean(axis=0)
    hi = np.percentile(vdata, 100.0 - tail, axis=0)
    lo = np.percentile(vdata, tail, axis=0)
    plt.plot(x, mu)
    plt.fill_between(x, y1=hi, y2=lo, alpha=.5)
    plt.plot(x, hi)
    plt.plot(x, lo)
    plt.plot(x, y)
    if show:
        plt.show()
def train_input_fn(x, y, queue_all=True):
    """Return an Estimator ``input_fn`` that serves ``(x, y)`` through a TF queue.

    Args:
        x: input values; reshaped to a ``(-1, 1)`` feature column.
        y: target values; reshaped to a ``(-1, 1)`` column.
        queue_all: forwarded to ``tf.train.batch`` as ``enqueue_many``.
            When True (the default), every element of ``x``/``y`` is queued
            separately, so each dequeued batch (``batch_size=1``) holds one
            sample. When False, the whole arrays form a single queue element.

    Returns:
        A zero-argument callable producing ``({'input': features}, targets)``.
    """
    def input_fn():
        queued = tf.train.batch({'input': x, 'target': y}, batch_size=1,
                                enqueue_many=queue_all)
        features = tf.reshape(queued['input'], [-1, 1])
        labels = tf.reshape(queued['target'], [-1, 1])
        return {'input': features}, labels

    return input_fn
def test_input_fn(x, y, repeats=10):
    """Return an Estimator ``input_fn`` serving ``repeats`` stacked copies of (x, y).

    The arrays are tiled pass-major (all of x, then all of x again, ...)
    and enqueued as a single element (``enqueue_many=False``). One dequeue
    of ``batch_size=1`` therefore returns every copy at once, reshaped to a
    ``(repeats * len(x), 1)`` column. Because the model in this file keeps
    dropout active (``is_training=True``), the copies get different
    predictions, which ``plot_uncertainty`` reshapes to ``(repeats, len(x))``.

    Args:
        x: input values.
        y: target values with the same leading length as ``x`` (NaN is fine;
            it is only passed through as the target tensor).
        repeats: number of copies to stack; must be >= 1.

    Raises:
        ValueError: if ``repeats`` < 1 or ``x`` and ``y`` differ in length.
    """
    if repeats < 1:
        raise ValueError('repeats must be >= 1, got %r' % (repeats,))
    x = np.asarray(x)
    y = np.asarray(y)
    # Catch mismatched inputs here rather than as a shape error deep inside
    # the graph once the two reshaped columns disagree.
    if x.shape[:1] != y.shape[:1]:
        raise ValueError('x and y must have the same length, got %s and %s'
                         % (x.shape[:1], y.shape[:1]))
    x = np.concatenate([x] * repeats)
    y = np.concatenate([y] * repeats)

    def input_fn():
        data = tf.train.batch({'input': x, 'target': y}, batch_size=1,
                              enqueue_many=False)
        data['input'] = tf.reshape(data['input'], [-1, 1])
        data['target'] = tf.reshape(data['target'], [-1, 1])
        return {'input': data['input']}, data['target']

    return input_fn


# The name starts with "test", so pytest would otherwise try to collect this
# helper as a test (and fail on the missing x/y fixtures).
test_input_fn.__test__ = False
def model_fn(features, targets, mode, params):
    """Build a small regression MLP whose dropout stays on in every mode.

    The network has:
      * a 100-unit relu6 layer;
      * three blocks of dropout (keep_prob .5) followed by a 100-unit
        fully connected layer using the slim default activation;
      * a linear 1-unit output.

    Dropout is built with ``is_training=True`` regardless of ``mode``. As a
    result, repeated forward passes at prediction time differ, and their
    spread serves as the uncertainty estimate (Monte Carlo dropout).

    Args:
        features: dict whose 'input' entry is the feature tensor (the
            input_fns in this file supply a ``(-1, 1)`` column).
        targets: target tensor; not used in INFER mode.
        mode: one of ``model_fn_lib.ModeKeys`` TRAIN / EVAL / INFER.
        params: hyperparameter dict; 'learning_rate' defaults to .001.
            May be None.

    Returns:
        ``learn.ModelFnOps`` with predictions ``{'yhat': ...}``. TRAIN and
        EVAL also include the MSE loss. TRAIN adds an Adam ``train_op``, and
        EVAL adds an "MSE" eval metric.
    """
    params = params or {}

    # NETWORK:::
    net = sm.fully_connected(features['input'], 100, activation_fn=tf.nn.relu6)
    for _ in range(3):
        # Deliberately left in training mode for all modes (MC dropout).
        net = sm.dropout(net, .5, is_training=True,
                         outputs_collections='collecto')
        net = sm.fully_connected(net, 100)
    net = sm.fully_connected(net, 1, activation_fn=None)
    yhat = {'yhat': net}
    # END NETWORK

    if mode == model_fn_lib.ModeKeys.INFER:
        # Prediction needs neither targets nor an optimizer. Skipping them
        # lets an input_fn omit targets and avoids building unused ops.
        return learn.ModelFnOps(mode=mode, predictions=yhat)

    # Loss and training Op
    loss = tf.losses.mean_squared_error(targets, yhat['yhat'])
    train_op = None
    if mode == model_fn_lib.ModeKeys.TRAIN:
        train_op = tf.contrib.layers.optimize_loss(
            loss=loss, global_step=tf.contrib.framework.get_global_step(),
            learning_rate=params.get('learning_rate', .001), optimizer="Adam")
    eval_metric_ops = None
    if mode == model_fn_lib.ModeKeys.EVAL:
        eval_metric_ops = {
            "MSE": tf.metrics.mean_squared_error(targets, yhat['yhat'])}
    return learn.ModelFnOps(
        mode=mode,
        predictions=yhat,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)
if __name__ == "__main__":
    # Number of stochastic (dropout-on) forward passes per x value.
    REPEATS = 100
    # Extra x points added on each side of the training range to probe
    # out-of-domain behavior.
    N_PAD = 30

    # build dataset
    x, y = build_toy_dataset()

    # run the model:
    esto = learn.Estimator(model_fn, params={"learning_rate": .001})
    esto.fit(input_fn=train_input_fn(x, y), steps=10000)

    # generate predictions: REPEATS stacked copies of x scored in one batch.
    predictions = esto.predict(input_fn=test_input_fn(x, y, REPEATS),
                               as_iterable=False)
    # plot:
    plot_uncertainty(predictions['yhat'], REPEATS, x, y)

    # See how the band behaves out of domain (beyond the sampled x range).
    # The grid boundaries are excluded so x.min()/x.max() are not duplicated.
    left = np.linspace(-10, x.min(), N_PAD, endpoint=False)
    right = np.linspace(x.max(), 10, N_PAD + 1)[1:]
    newx = np.concatenate([left, x, right])
    # No observations exist outside the training range; NaN leaves gaps.
    no_obs = np.full(N_PAD, np.nan)
    newy = np.concatenate([no_obs, y, no_obs])
    predictions = esto.predict(input_fn=test_input_fn(newx, newy, REPEATS),
                               as_iterable=False)
    plot_uncertainty(predictions['yhat'], REPEATS, newx, newy)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment