Skip to content

Instantly share code, notes, and snippets.

@sergeyf
Last active July 21, 2020 22:08
Show Gist options
  • Save sergeyf/cf20b2759a7d38035f30384769bed9df to your computer and use it in GitHub Desktop.
Save sergeyf/cf20b2759a7d38035f30384769bed9df to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
import seaborn
from keras.layers import Input, Dense, merge, ELU, Dropout
from keras.models import Model
from keras.regularizers import l2
from keras import backend as K
from keras.optimizers import rmsprop, adam
# From Bishop's paper - to get point estimates
def mdnOutputToPreds(output, mode=0):
    """Turn mixture-of-Betas network output into per-row point estimates.

    ``output`` is indexed as a 2-D array whose last axis holds three
    equal-width groups laid out as ``[alphas | betas | pis]``: the two Beta
    shape parameters and the mixing weight of each component.

    One component is picked per row and its Beta mean and standard deviation
    are returned:

    * ``mode=0`` -- the component maximising ``pi / sigma``;
    * ``mode=1`` -- the component with the largest mixing weight ``pi``.

    Returns a tuple ``(means, sigmas)`` of 1-D arrays with one entry per row.

    Raises ``ValueError`` if ``mode`` is neither 0 nor 1.
    """
    if mode not in (0, 1):
        raise ValueError("mode must be 0 or 1, got %r" % (mode,))

    output = np.asarray(output)
    # Floor division keeps the group width an integer; true division would
    # produce a float, which is not a valid slice bound.
    m = output.shape[-1] // 3
    alphas = output[:, 0 * m: 1 * m]
    betas = output[:, 1 * m: 2 * m]
    pis = output[:, 2 * m: 3 * m]

    # Mean of Beta(alpha, beta); forced to 0 wherever alpha is exactly 0 so
    # that the 0/0 case (alpha == beta == 0) does not leave a NaN behind.
    means = alphas / (alphas + betas)
    means[alphas == 0] = 0
    # Standard deviation of Beta(alpha, beta).
    sigmas = np.sqrt(
        alphas * betas / ((alphas + betas) ** 2 * (alphas + betas + 1))
    )

    if mode == 0:
        max_components = np.argmax(pis / sigmas, axis=1)
    else:
        max_components = np.argmax(pis, axis=1)

    rows = np.arange(len(means))
    return means[rows, max_components], sigmas[rows, max_components]
# Bishop's exponential activation results in explosively large sigmas
def safeBeta(x):
    """Positive activation used for the Beta shape parameters.

    Sums ``relu(x)`` and ``exp(min(x, 0))``, which gives ``x + 1`` for
    non-negative inputs and ``exp(x)`` for negative ones, then clips the
    result to the range [1e-6, 1000].
    """
    # (x - |x|) / 2 equals x for negative inputs and 0 otherwise.
    negative_part = 0.5 * (x - K.abs(x))
    activated = K.relu(x) + K.exp(negative_part)
    return K.clip(activated, 1e-6, 1000)
#
def gammaln(x):
    """Approximate ``log(Gamma(x))`` element-wise on a backend tensor.

    Fast approximation from Paul Mineiro:
    http://www.machinedlearnings.com/2011/06/faster-lda.html
    """
    shifted = 3.0 + x
    log_product = K.log(x * (1.0 + x) * (2.0 + x))
    return (
        -2.081061466
        - x
        + 0.0833333 / shifted
        - log_product
        + (2.5 + x) * K.log(shifted)
    )
# negative log likelihood loss for a mixture of betas
def neg_beta_mixture_likelihood(true, parameters):
    """Negative log-likelihood of ``true`` under a mixture of Beta densities.

    ``parameters`` carries, along its last axis, three equal-width groups
    ``[alphas | betas | pis]`` (Beta shape parameters and mixing weights).
    ``true`` is clipped to [1e-6, 1 - 1e-6] and repeated along its last axis
    once per component (assumes ``true`` has shape (batch, 1) -- TODO confirm
    against the caller).

    Returns one loss value per row of ``parameters``.
    """
    m = K.shape(parameters)[-1] // 3
    alphas = parameters[:, 0 * m: 1 * m]
    betas = parameters[:, 1 * m: 2 * m]
    pis = parameters[:, 2 * m: 3 * m]

    # Clip the targets away from 0 and 1 so both logarithms stay finite.
    true_repeated = K.repeat_elements(
        K.clip(true, 1e-6, 1 - 1e-6), m, axis=-1
    )

    # Per-component log Beta density:
    #   (a - 1) log y + (b - 1) log(1 - y) + lnG(a + b) - lnG(a) - lnG(b)
    d1 = (alphas - 1.0) * K.log(true_repeated)
    d2 = (betas - 1.0) * K.log(1.0 - true_repeated)
    f1 = d1 + d2
    f2 = gammaln(alphas)
    f3 = gammaln(betas)
    f4 = gammaln(alphas + betas)
    exponent = f1 + f4 - f2 - f3

    # Log-sum-exp: subtract the per-row maximum before exponentiating and
    # add it back after the log.
    max_exponent = K.max(exponent, axis=-1, keepdims=True)
    max_exponent_repeated = K.repeat_elements(max_exponent, m, axis=-1)
    likelihood = pis * K.exp(exponent - max_exponent_repeated)

    # keepdims=True keeps the sum at (batch, 1) so it lines up with
    # max_exponent.  Without it a (batch,) vector is added to a (batch, 1)
    # column, which broadcasts to (batch, batch) and mixes samples together.
    log_mixture = (
        K.log(K.sum(likelihood, axis=-1, keepdims=True)) + max_exponent
    )
    return K.mean(-log_mixture, axis=-1)
# multiple output function
def f_of_x(X, w):
    """Evaluate a randomly two-branched function of the inner product X.w.

    For every row of the 2-D array ``X`` the inner product with ``w`` is
    passed through either ``cos`` or ``sin``; the branch is chosen per row by
    the sign of an independent standard-normal draw from NumPy's global
    random state.

    Returns a 1-D float array with one value per row of ``X``.
    """
    num_rows, _ = X.shape
    projection = np.dot(X, w)
    # Negative draw -> cos branch, otherwise -> sin branch.
    use_cos = np.random.randn(num_rows) < 0.0
    out = np.zeros(num_rows)
    for mask, func in ((use_cos, np.cos), (~use_cos, np.sin)):
        out[mask] = func(projection[mask])
    return out
# Generate some simulated data: standard-normal inputs pushed through f_of_x
# with a single random weight vector shared by the train and test splits.
# NOTE: no random seed is set in the code shown here, so each run draws
# different data.
# d: number of input features; ntr / nts: number of train / test rows.
d = 10
ntr = 100000
nts = 10000
w = np.random.rand(d)
Xtr = np.random.randn(ntr,d)
ytr = f_of_x(Xtr,w)
Xts = np.random.randn(nts,d)
yts = f_of_x(Xts,w)
# Min-max rescale the targets towards [0, 1].  Both splits use the *training*
# min and max, so a test target is only inside [0, 1] if it falls within the
# range seen in training.
ymin,ymax = ytr.min(),ytr.max()
ytr = (ytr - ymin)/(ymax - ymin)
yts = (yts - ymin)/(ymax - ymin)
# Network architecture: a two-hidden-layer MLP feeding three parallel heads
# (alphas, betas, pis), each of width m.
# NOTE(review): `merge`, `W_regularizer` and `Model(input=..., output=...)`
# look like Keras 1.x-era names -- confirm against the installed Keras.
m = 50 # number of components
lam = 1e-6 # l2 regularizer for each layer
dropout_p = 0.3
inputs = Input(shape=(d,))
# Hidden layers: Dense(100) -> ELU -> Dropout -> Dense(50) -> ELU.
mlp = Dense(100, W_regularizer=l2(lam))(inputs)
mlp = ELU()(mlp)
mlp = Dropout(dropout_p)(mlp)
mlp = Dense(50, W_regularizer=l2(lam))(mlp)
mlp = ELU()(mlp)
# Shape-parameter heads use safeBeta; the mixing-weight head uses softmax.
alphas = Dense(m, W_regularizer=l2(lam), activation=safeBeta)(mlp)
betas = Dense(m, W_regularizer=l2(lam), activation=safeBeta)(mlp)
pis = Dense(m, W_regularizer=l2(lam), activation='softmax')(mlp)
# The concatenation order [alphas, betas, pis] has to match the slicing done
# in neg_beta_mixture_likelihood and mdnOutputToPreds.
parameters = merge([alphas,betas,pis],mode='concat')
model = Model(input=inputs, output=parameters)
# RMSprop with both clipnorm and clipvalue set to 10.
optimizer = rmsprop(0.001,clipnorm=10,clipvalue=10)
model.compile(optimizer=optimizer,
loss=neg_beta_mixture_likelihood)
# fit model
history = model.fit(Xtr, ytr, nb_epoch=25, batch_size=32, verbose=2)

# plot model predictions vs truth
output = model.predict(Xts)
preds, uncertainties = mdnOutputToPreds(output, mode=0)
plt.scatter(yts, preds)
plt.xlabel('yts true')
plt.ylabel('yts prediction')

# How well does the predicted standard deviation track the absolute error?
errors = np.abs(preds - yts)
# A single pre-formatted argument keeps the printed text identical under the
# Python 2 print statement and the Python 3 print function.
print("Correlation between uncertainty and error: %s"
      % np.corrcoef(errors, uncertainties)[0, 1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment