import matplotlib.pyplot as plt
import numpy as np
import seaborn
from keras.layers import Input, Dense, merge, ELU, Dropout
from keras.models import Model
from keras.regularizers import l2
from keras import backend as K
from keras.optimizers import rmsprop, adam
# From Bishop's paper - turn the mixture output into point estimates
# (mean and sigma of one selected Beta component)
def mdnOutputToPreds(output, mode=0):
    m = output.shape[-1] // 3
    alphas = output[:, 0 * m: 1 * m]
    betas = output[:, 1 * m: 2 * m]
    pis = output[:, 2 * m: 3 * m]
    means = alphas / (alphas + betas)
    means[alphas == 0] = 0
    sigmas = np.sqrt(alphas * betas / ((alphas + betas) ** 2 * (alphas + betas + 1)))
    if mode == 0:
        # pick the component with the largest pi/sigma ratio
        max_components = np.argmax(pis / sigmas, axis=1)
    elif mode == 1:
        # pick the component with the largest mixture weight
        max_components = np.argmax(pis, axis=1)
    return means[np.arange(len(means)), max_components], sigmas[np.arange(len(means)), max_components]
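# Quick sanity check of mdnOutputToPreds on a hand-built output row (not part
# of the original gist): two components with alphas = [2, 8], betas = [2, 2],
# pis = [0.3, 0.7]. With mode=1 the second component wins, so the point
# estimate should be its mean 8 / (8 + 2) = 0.8.
_demo_output = np.array([[2.0, 8.0, 2.0, 2.0, 0.3, 0.7]])
_demo_mean, _demo_sigma = mdnOutputToPreds(_demo_output, mode=1)
# _demo_mean -> [0.8], _demo_sigma -> [~0.12]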
# Bishop's exponential activation results in explosively large sigmas;
# this softer alternative behaves like exp(x) for x < 0 and x + 1 for x >= 0,
# clipped to [1e-6, 1000] so the Beta parameters stay positive and bounded
def safeBeta(x):
    pos = K.relu(x)
    neg = (x - K.abs(x)) * 0.5  # equals min(x, 0)
    return K.clip(pos + K.exp(neg), 1e-6, 1000)
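# Illustration (not in the original gist) of the activation's range:
# expected output is approximately [1e-6, 0.37, 1.0, 2.0, 1000.0]
print(K.eval(safeBeta(K.variable([-20.0, -1.0, 0.0, 1.0, 2000.0]))))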
def gammaln(x):
    # fast approximate gammaln from Paul Mineiro
    # http://www.machinedlearnings.com/2011/06/faster-lda.html
    logterm = K.log(x * (1.0 + x) * (2.0 + x))
    xp3 = 3.0 + x
    return -2.081061466 - x + 0.0833333 / xp3 - logterm + (2.5 + x) * K.log(xp3)
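# Optional accuracy check (not in the original gist): the approximation tracks
# scipy.special.gammaln closely for positive arguments; for [0.5, 1.0, 10.0]
# the exact values are about [0.572, 0.0, 12.802].
#   from scipy.special import gammaln as exact_gammaln
#   print(K.eval(gammaln(K.variable([0.5, 1.0, 10.0]))))
#   print(exact_gammaln(np.array([0.5, 1.0, 10.0])))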
# negative log likelihood loss for a mixture of Betas
def neg_beta_mixture_likelihood(true, parameters):
    m = K.shape(parameters)[-1] // 3
    alphas = parameters[:, 0 * m: 1 * m]
    betas = parameters[:, 1 * m: 2 * m]
    pis = parameters[:, 2 * m: 3 * m]
    # clip the targets away from 0 and 1 so the logs below stay finite
    true_repeated = K.repeat_elements(K.clip(true, 1e-6, 1 - 1e-6), m, axis=-1)
    d1 = (alphas - 1.0) * K.log(true_repeated)
    d2 = (betas - 1.0) * K.log(1.0 - true_repeated)
    f1 = d1 + d2
    f2 = gammaln(alphas)
    f3 = gammaln(betas)
    f4 = gammaln(alphas + betas)
    # exponent is the per-component Beta log density
    exponent = f1 + f4 - f2 - f3
    # subtract the per-sample max (log-sum-exp trick) for numerical stability
    max_exponent = K.max(exponent, axis=-1, keepdims=True)
    max_exponent_repeated = K.repeat_elements(max_exponent, m, axis=-1)
    likelihood = pis * K.exp(exponent - max_exponent_repeated)
    return K.mean(-(K.log(K.sum(likelihood, axis=-1, keepdims=True)) + max_exponent), axis=-1)
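# Sketch of the math being computed above (not in the original gist): for each
# component k, exponent_k = log Beta(y | alpha_k, beta_k), i.e.
#   (alpha_k - 1) log y + (beta_k - 1) log(1 - y)
#   + gammaln(alpha_k + beta_k) - gammaln(alpha_k) - gammaln(beta_k),
# and the per-sample loss is -log sum_k pi_k exp(exponent_k), evaluated with
# the usual log-sum-exp shift by max_k exponent_k so the exponentials cannot
# overflow.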
# multimodal target function: each row goes through either a sin or a cos
# of the same projection, so y given x has two modes
def f_of_x(X, w):
    n, d = X.shape
    X_dot_w = np.dot(X, w)
    y = np.zeros(n)
    # the inner product randomly goes through a sin or a cos
    cos_flag = np.random.randn(n) < 0.0
    sin_flag = ~cos_flag
    y[cos_flag] = np.cos(X_dot_w[cos_flag])
    y[sin_flag] = np.sin(X_dot_w[sin_flag])
    return y
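# Small illustration (not in the original gist) of why a mixture is needed:
# feeding several copies of the same input typically yields two distinct
# targets, cos(x.w) and sin(x.w), so the conditional distribution is bimodal:
#   _X_demo = np.tile(np.random.randn(1, 10), (5, 1))
#   f_of_x(_X_demo, np.random.rand(10))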
# generate some simulated data
d = 10
ntr = 100000
nts = 10000
w = np.random.rand(d)
Xtr = np.random.randn(ntr, d)
ytr = f_of_x(Xtr, w)
Xts = np.random.randn(nts, d)
yts = f_of_x(Xts, w)
# scale the targets to [0, 1] using the training min/max (needed for the Beta likelihood)
ymin, ymax = ytr.min(), ytr.max()
ytr = (ytr - ymin) / (ymax - ymin)
yts = (yts - ymin) / (ymax - ymin)
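# Because the test targets are scaled with the *training* min/max, a handful
# can fall slightly outside [0, 1]; the loss clips targets to [1e-6, 1 - 1e-6],
# so this is harmless. (Check added here, not part of the original gist.)
print("fraction of yts outside [0, 1]: %.5f" % np.mean((yts < 0) | (yts > 1)))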
# network architecture
m = 50          # number of mixture components
lam = 1e-6      # l2 regularizer for each layer
dropout_p = 0.3

inputs = Input(shape=(d,))
mlp = Dense(100, W_regularizer=l2(lam))(inputs)
mlp = ELU()(mlp)
mlp = Dropout(dropout_p)(mlp)
mlp = Dense(50, W_regularizer=l2(lam))(mlp)
mlp = ELU()(mlp)

# three heads: positive alphas and betas via safeBeta, softmax mixture weights;
# the concatenation order must match the slicing in the loss
alphas = Dense(m, W_regularizer=l2(lam), activation=safeBeta)(mlp)
betas = Dense(m, W_regularizer=l2(lam), activation=safeBeta)(mlp)
pis = Dense(m, W_regularizer=l2(lam), activation='softmax')(mlp)
parameters = merge([alphas, betas, pis], mode='concat')

model = Model(input=inputs, output=parameters)
optimizer = rmsprop(0.001, clipnorm=10, clipvalue=10)
model.compile(optimizer=optimizer,
              loss=neg_beta_mixture_likelihood)

# fit model
history = model.fit(Xtr, ytr, nb_epoch=25, batch_size=32, verbose=2)
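# Optional (not in the original gist): the per-epoch training loss is recorded
# by Keras in history.history['loss'], so the curve can be inspected with e.g.
#   plt.figure(); plt.plot(history.history['loss'])
#   plt.xlabel('epoch'); plt.ylabel('training loss')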
# plot model predictions vs truth
output = model.predict(Xts)
preds, uncertainties = mdnOutputToPreds(output, mode=0)
plt.scatter(yts, preds); plt.xlabel('yts true'); plt.ylabel('yts prediction')

errors = np.abs(preds - yts)
print("Correlation between uncertainty and error: %.3f" % np.corrcoef(errors, uncertainties)[0, 1])