Created
August 31, 2016 03:20
-
-
Save mujjingun/7403363ff0249e5b8bbe9d5490e5da80 to your computer and use it in GitHub Desktop.
Mixture Density Network (MDN) implementation in Theano
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Mon Aug 22 00:36:48 2016 | |
@author: park | |
""" | |
#%% train | |
import theano | |
import theano.tensor as T | |
import numpy as np | |
# Use single precision throughout: set Theano's global float type and keep
# the matching NumPy scalar constructor for casting Python numbers.
floatX = 'float32'
theano.config.floatX = floatX
floatX_cast = np.float32
# similar to rmsprop | |
def adam(cost, params, lr=0.0002, b1=0.1, b2=0.001, e=1e-8):
    """Build Adam-style update rules for a Theano training function.

    Note the convention used here: ``b1`` and ``b2`` weight the *new*
    gradient (resp. squared gradient), so the running averages decay with
    ``1 - b1`` and ``1 - b2``.

    cost = scalar symbolic expression to minimise
    params = list of shared variables to optimise
    lr = base learning rate
    b1, b2 = mixing weights for the first / second moment estimates
    e = constant added to the denominator to avoid division by zero

    Returns a list of (shared_variable, new_value) pairs, usable as the
    ``updates`` argument of ``theano.function``.
    """
    gradients = T.grad(cost, params)

    # One step counter shared by all parameters; it drives the bias
    # correction folded into the effective learning rate.
    step = theano.shared(np.zeros(1, dtype=floatX)[0])
    next_step = step + 1.
    correction1 = 1. - (1. - b1) ** next_step
    correction2 = 1. - (1. - b2) ** next_step
    step_size = lr * (T.sqrt(correction2) / correction1)

    updates = []
    for param, grad in zip(params, gradients):
        # Running averages of the gradient and of its square.
        first_moment = theano.shared(
            np.zeros_like(param.get_value(), dtype=floatX))
        second_moment = theano.shared(
            np.zeros_like(param.get_value(), dtype=floatX))
        new_first = (b1 * grad) + ((1. - b1) * first_moment)
        new_second = (b2 * T.sqr(grad)) + ((1. - b2) * second_moment)
        direction = new_first / (T.sqrt(new_second) + e)
        new_param = param - (step_size * direction)
        updates.extend([
            (first_moment, new_first),
            (second_moment, new_second),
            (param, new_param),
        ])

    # The counter advances once per training step, not once per parameter.
    updates.append((step, next_step))
    return updates
class Layer(object):
    """Fully connected affine layer computing ``x . W + b``.

    No non-linearity is applied here; callers wrap the output themselves.

    Attributes:
        W: shared weight matrix of shape (n_in, n_out), drawn uniformly
           from [-init_size, init_size].
        b: shared bias vector of shape (n_out,), initialised to 0.1.
        params: list ``[W, b]`` of the trainable shared variables.
    """

    def __init__(self, n_in, n_out):
        """Create the layer parameters.

        n_in = number of input features
        n_out = number of output features
        """
        # Uniform bound sqrt(6 / (fan_in + fan_out)).  The float literal
        # forces true division: under Python 2, ``6 / (n_in + n_out)`` with
        # integer sizes floors to 0 whenever n_in + n_out > 6, which would
        # initialise every weight to exactly zero.
        init_size = np.sqrt(6.0 / (n_in + n_out))
        self.W = theano.shared(
            value=np.asarray(
                np.random.uniform(-init_size, init_size, (n_in, n_out)),
                dtype=floatX
            ),
            name='Wxy',
            borrow=True
        )
        self.b = theano.shared(
            value=np.zeros(
                (n_out,),
                dtype=floatX
            ) + 0.1,
            name='by',
            borrow=True
        )
        # parameters of the model
        self.params = [self.W, self.b]

    def get(self, x):
        """Return the raw (pre-activation) output ``x . W + b`` for input x."""
        out = T.dot(x, self.W) + self.b
        return out
def unpack(y, N, M, b):
    """Split the raw network output into mixture model parameters.

    y = (minibatch_size, (N + 2) * M) network output
    N = output dimension
    M = number of mixture components
    b = sampling bias (scalar); it is subtracted from the log-sigma and
        scales the mixing logits by (1 + b), so b = 0 leaves both untouched

    Returns [mu, sigma, mixing] with shapes
    (minibatch_size, N, M), (minibatch_size, M), (minibatch_size, M).
    """
    # Per sample: N rows of means, one row of log-sigmas, one row of logits.
    parts = y.reshape((-1, N + 2, M))
    mean_rows = parts[:, :N, :]
    log_sigma = parts[:, N, :]
    mix_logits = parts[:, N + 1, :]

    sigma = T.exp(log_sigma - b)
    mixing = T.nnet.softmax(mix_logits * (1 + b))
    return [mean_rows, sigma, mixing]
def logsumexp(x, axis=None):
    """Compute log(sum(exp(x))) along `axis`, shifted by the max for stability.

    The reduced axis is kept (keepdims=True), so the result broadcasts
    against x.  A small epsilon is added inside the log.
    """
    # NOTE(review): after subtracting the max, the sum is >= 1 for finite
    # inputs, so the epsilon looks unnecessary -- confirm before removing.
    epsilon = 0.00001
    peak = T.max(x, axis=axis, keepdims=True)
    shifted_sum = T.sum(T.exp(x - peak), axis=axis, keepdims=True)
    return T.log(shifted_sum + epsilon) + peak
def NLL(mu, sigma, mixing, y):
    """Mean negative log likelihood of y under the mixture, i.e. of P(y|x).

    y = T.matrix('y') # (minibatch_size, output_size)
    mu = T.tensor3('mu') # (minibatch_size, output_size, n_components)
    sigma = T.matrix('sigma') # (minibatch_size, n_components)
    mixing = T.matrix('mixing') # (minibatch_size, n_components)

    sigma acts as the per-component isotropic variance: it divides the
    squared distance directly and enters the normaliser as log(2*pi*sigma).
    """
    # Squared distance of y to every component mean: (minibatch, n_components)
    sq_dist = T.sum((y.dimshuffle(0, 1, 'x') - mu) ** 2, axis=1)
    n_dims = y.shape[1].astype(floatX)

    # Log of (mixing weight * Gaussian density) for each component.
    quad_term = -0.5 * T.inv(sigma) * sq_dist
    norm_term = .5 * n_dims * T.log(2 * np.pi * sigma)
    log_joint = quad_term + T.log(mixing) - norm_term

    # Marginalise over components, then average over the minibatch.
    per_sample = logsumexp(log_joint, axis=1)
    return -T.mean(per_sample)
# Symbolic inputs: network input and regression target.
x = T.matrix('x')
y = T.matrix('y')
# sampling bias
b = T.scalar('b')

# Output dimension and number of mixture components.
N = 2
M = 20
out_size = (N + 2) * M

# Construct network with 2 layers: affine -> ReLU -> affine
layer = Layer(1, out_size)
layer2 = Layer(out_size, out_size)
params = layer.params + layer2.params

x1 = layer.get(x)
out = layer2.get(T.nnet.relu(x1))
mu, sigma, mixing = unpack(out, N, M, b)

nll = NLL(mu, sigma, mixing, y)
updates = adam(nll, params)

# Training always runs with the sampling bias pinned to zero.
train = theano.function(
    inputs=[x, y],
    outputs=[nll],
    updates=updates,
    givens={b: floatX_cast(0)},
)

# Prediction exposes the bias so the caller can choose it.
predict = theano.function(
    inputs=[x, b],
    outputs=[mu, sigma, mixing],
)
#%% train | |
# Exponentially smoothed training loss (None until the first step).
l = None
for i in range(100000):
    if i % 100 == 0:
        # Every 100 steps rebuild the data set: a grid over [-1, 1)^2 whose
        # targets are (y, z) with z = +/-(x^2 + y^2), all lightly jittered.
        grid = np.arange(-1, 1, 0.08)
        x, y = np.meshgrid(grid, grid)
        x, y = x.flatten(), y.flatten()
        z = x**2 + y**2

        # Duplicate every point so both the +z and -z branch are present.
        x = np.concatenate([x, x])
        y = np.concatenate([y, y])
        z = np.concatenate([z, -z])

        # Noise is drawn in the order x, y, z.
        x = x + np.random.randn(*x.shape) * 0.001
        y = y + np.random.randn(*y.shape) * 0.001
        z = z + np.random.randn(*z.shape) * 0.001

        x = x[:, np.newaxis].astype(floatX)
        y = np.stack([y, z], axis=1).astype(floatX)

    # Random minibatch of 20 samples (drawn with replacement).
    selection = np.random.choice(x.shape[0], 20)
    sx = x[selection]
    sy = y[selection]
    cl = train(sx, sy)[0]

    l = cl if l is None else l * 0.999 + cl * 0.001
    if i % 500 == 0:
        print(l)
#%% predict | |
from mpl_toolkits.mplot3d import Axes3D | |
import matplotlib.pyplot as plt | |
import matplotlib as mpl | |
from matplotlib.colors import ListedColormap | |
# Reference surfaces z = +/-(x^2 + y^2) drawn as faint wireframes.
x = np.arange(-1, 1, 0.08)
y = np.arange(-1, 1, 0.08)
x, y = np.meshgrid(x, y)
z = x**2 + y**2

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for surface in (z, -z):
    ax.plot_wireframe(x, y, surface, alpha=0.2)

# Query the network on a dense 1-D grid of inputs, with zero sampling bias.
x = np.expand_dims(np.arange(-1, 1, 0.005), axis=1).astype(floatX)
mu, sigma, mixing = predict(x, 0)

# Black colormap whose alpha ramps from 0 to 1, so a component's opacity
# encodes its mixing weight.
rgba = np.zeros((256, 4))
rgba[:, -1] = np.linspace(0, 1, 256)
my_cmap = ListedColormap(rgba)

# plot means
norm = mpl.colors.Normalize(vmin=0., vmax=1.)
for i in range(M):
    my = mu[:, 0, i]
    mz = mu[:, 1, i]
    mix = mixing[:, i]
    # Marker size grows as sigma shrinks.
    size = np.log(1 / sigma[:, i]) * 5
    ax.scatter(x, my, mz, s=size, lw=0, c=mix, norm=norm, cmap=my_cmap)
def sample(mu, sigma, mix):
    """Draw one point from a mixture density model.

    mu (ndim, M)   component means, one column per component
    sigma (M)      per-component variance (used as the diagonal of the
                   covariance matrix)
    mix (M)        mixing probabilities

    Returns an array of shape (ndim,).
    """
    # First choose a component according to the mixing probabilities ...
    chosen = np.random.choice(mix.size, p=mix)

    # ... then sample from its isotropic Gaussian.
    n_dims = mu.shape[0]
    variance = sigma[chosen].astype(float)
    covariance = np.diag(np.repeat(variance, n_dims))
    return np.random.multivariate_normal(mu[:, chosen], covariance)
# plot sampled points: one draw from the predicted mixture per input value
z = np.array([
    sample(mu[i], sigma[i], mixing[i])
    for i in range(x.size)
]).reshape((x.size, N))
plt.plot(x, z[:, 0], z[:, 1], 'ro', lw=0, alpha=0.5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment