Skip to content

Instantly share code, notes, and snippets.

@trcook
Last active March 17, 2017 15:39
Show Gist options
  • Save trcook/82696fb193e11a1a4c45b71b4b8b1131 to your computer and use it in GitHub Desktop.
Save trcook/82696fb193e11a1a4c45b71b4b8b1131 to your computer and use it in GitHub Desktop.
Use a TensorFlow Estimator with Monte Carlo dropout to generate uncertainty estimates.
import tensorflow.contrib.learn as learn
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
import tensorflow as tf
import tensorflow.contrib.slim as sm
import numpy as np
import functools
import matplotlib.pyplot as plt
def build_toy_dataset(N=40, noise_std=0.1, x_min=-5.0, x_max=5.0):
    """Sample a noisy cosine curve on an evenly spaced 1-D grid.

    Args:
        N: number of sample points.
        noise_std: standard deviation of the additive Gaussian noise on y.
        x_min, x_max: inclusive endpoints of the x grid (defaults reproduce
            the original [-5, 5] range).

    Returns:
        (x, y): two 1-D float arrays of length N, where
        y = cos(x) + Normal(0, noise_std). Noise comes from numpy's global
        random state, so seed ``np.random`` for reproducibility.
    """
    x = np.linspace(x_min, x_max, num=N)
    y = np.cos(x) + np.random.normal(0, noise_std, size=N)
    return x, y
def plot_uncertainty(mods,repeats,x,y):
vdata=mods.reshape([repeats,-1])
mu=vdata.mean(axis=0)
hi=np.percentile(vdata,97.5,axis=0)
lo=np.percentile(vdata,2.5,axis=0)
plt.plot(x,mu)
plt.fill_between(x,y1=hi,y2=lo,alpha=.5)
plt.plot(x,hi)
plt.plot(x,lo)
plt.plot(x,y)
plt.show()
def train_input_fn(x, y, queue_all=True):
    """Create an Estimator input function that batches (x, y) with batch_size=1.

    Args:
        x, y: input and target data handed to ``tf.train.batch``.
        queue_all: forwarded as ``enqueue_many``. When True, TensorFlow treats
            the first dimension of x/y as separate examples, so each step sees
            one example; when False, the whole arrays form a single example.

    Returns:
        A zero-argument ``input_fn`` returning ``({'input': column}, column)``,
        where both columns are reshaped to ``[-1, 1]``.
    """
    def _as_column(tensor):
        # Flatten into a single-feature column of shape (-1, 1).
        return tf.reshape(tensor, [-1, 1])

    def input_fn():
        batched = tf.train.batch(
            {'input': x, 'target': y},
            batch_size=1,
            enqueue_many=queue_all,
        )
        features = {'input': _as_column(batched['input'])}
        labels = _as_column(batched['target'])
        return features, labels

    return input_fn
def test_input_fn(x,y,repeats=10):
x=np.concatenate([x for i in range(repeats)])
print(x.shape)
y=np.concatenate([y for i in range(repeats)])
print(y.shape)
def input_fn():
data=tf.train.batch({'input':x,'target':y},batch_size=1,enqueue_many=False)
data['input']=tf.reshape(data['input'],[-1,1])
data['target']=tf.reshape(data['target'],[-1,1])
return {'input':data['input']},data['target']
return input_fn
def model_fn(features, targets, mode, params):
    """Build an MLP regressor whose dropout stays active in every mode.

    Dropout is created with ``is_training=True`` regardless of ``mode``. Each
    forward pass therefore samples a different dropout mask, which is what
    makes repeated predictions on the same input vary (Monte Carlo dropout).

    Args:
        features: dict with key ``'input'``; presumably a (batch, 1) float
            tensor, which is what the input functions in this file produce.
        targets: regression targets, same shape as the prediction.
        mode: Estimator mode key; the graph does not branch on it.
        params: dict of hyper-parameters, or None. Recognized keys (defaults):
            ``learning_rate`` (.001), ``hidden_units`` (100),
            ``num_hidden_layers`` (3), ``keep_prob`` (.5).

    Returns:
        ``learn.ModelFnOps`` with predictions ``{'yhat': ...}``, the MSE loss,
        an Adam train op, and an ``"MSE"`` eval metric.
    """
    # The Estimator may pass params=None; treat that as "all defaults".
    params = params or {}
    hidden_units = params.get('hidden_units', 100)
    num_hidden_layers = params.get('num_hidden_layers', 3)
    keep_prob = params.get('keep_prob', .5)
    learning_rate = params.get('learning_rate', .001)

    # NETWORK:::
    net = sm.fully_connected(features['input'], hidden_units, activation_fn=tf.nn.relu6)
    for _ in range(num_hidden_layers):
        # is_training=True on purpose: keep dropout on at prediction time.
        net = sm.dropout(net, keep_prob, is_training=True, outputs_collections='collecto')
        net = sm.fully_connected(net, hidden_units)
    yhat = {'yhat': sm.fully_connected(net, 1, activation_fn=None)}
    # END NETWORK

    # Loss and training Op
    loss = tf.losses.mean_squared_error(targets, yhat['yhat'])
    train_op = tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=tf.contrib.framework.get_global_step(),
        learning_rate=learning_rate,
        optimizer="Adam")
    eval_metric_ops = {"MSE": tf.metrics.mean_squared_error(targets, yhat['yhat'])}
    return learn.ModelFnOps(
        mode=mode,
        predictions=yhat,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)
# Number of stochastic prediction passes per input point. test_input_fn and
# plot_uncertainty must agree on this value, so it is defined once here.
N_REPEATS = 100
# Number of extra grid points added on each side of the training range.
N_PAD = 30

# build dataset
x, y = build_toy_dataset()

# run the model:
esto = learn.Estimator(model_fn, params={"learning_rate": .001})
esto.fit(input_fn=train_input_fn(x, y), steps=10000)

# generate predictions (N_REPEATS copies of x in a single prediction run):
predictions = esto.predict(input_fn=test_input_fn(x, y, N_REPEATS), as_iterable=False)
# plot:
plot_uncertainty(predictions['yhat'], N_REPEATS, x, y)

# See how the band behaves out of domain (x beyond the sampled range).
# The pads exclude the boundary points so x.min()/x.max() are not duplicated.
left_pad = np.linspace(-10, x.min(), N_PAD, endpoint=False)
right_pad = np.linspace(x.max(), 10, N_PAD + 1)[1:]
newx = np.concatenate([left_pad, x, right_pad])
# No observations exist outside the training range; NaN leaves gaps in the plot.
no_target = np.full(N_PAD, np.nan)
newy = np.concatenate([no_target, y, no_target])
predictions = esto.predict(input_fn=test_input_fn(newx, newy, N_REPEATS), as_iterable=False)
plot_uncertainty(predictions['yhat'], N_REPEATS, newx, newy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment