Created
September 12, 2019 13:07
-
-
Save WenchaoDing/0f6539688715c568960075f77caa9ad3 to your computer and use it in GitHub Desktop.
simp model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
# Filename : model.py
# Author : Wenchao Ding
# Date : 2018-06-25
'''
import math | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.distributions import Categorical | |
# Normalising constant 1 / sqrt(2 * pi) of the univariate Gaussian density.
ONEOVERSQRT2PI = 1.0 / math.sqrt(2*math.pi)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Shared negative-log-likelihood criterion; simp_loss feeds it the
# log-softmax area scores.
nllloss = nn.NLLLoss()
nllloss = nllloss.to(device)
class SIMP(nn.Module):
    """Shared embedding network with an area classifier and one MDN per area.

    The input is embedded by a four-layer tanh MLP followed by dropout. The
    embedding feeds a log-softmax classifier over ``num_mdns`` classes and
    ``num_mdns`` independent mixture density networks.

    Arguments:
        num_mdns: number of MDN heads (and of classifier outputs).
        in_features: size of each input sample.
        out_features: dimensionality of each Gaussian in the MDN heads.
        num_gaussians: number of mixture components per MDN head.
        hidden_size: width of every embedding layer (default 400, the value
            that used to be hard-coded).
        dropout: dropout probability applied to the embedding (default 0.5).
    """

    def __init__(self, num_mdns, in_features, out_features, num_gaussians,
                 hidden_size=400, dropout=0.5):
        super(SIMP, self).__init__()
        self.hidden_size = hidden_size
        # Layers are created in the same order as before so that seeded
        # initialisation and state_dict keys stay the same for the defaults.
        self.simp_embed = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(p=dropout))
        self.area = nn.Sequential(
            nn.Linear(hidden_size, num_mdns),
            nn.LogSoftmax(dim=1))
        self.mdns = nn.ModuleList(
            [MDN(hidden_size, out_features, num_gaussians).to(device)
             for _ in range(num_mdns)])

    def forward(self, minibatch):
        """Run the embedding, the area classifier and every MDN head.

        Returns:
            (area_score, mdn_inferences): ``area_score`` holds the
            log-softmax class scores; ``mdn_inferences`` is a flat list
            ``[pi_0, sigma_0, mu_0, pi_1, sigma_1, mu_1, ...]`` with one
            triple per MDN head, in head order.
        """
        embeds = self.simp_embed(minibatch)
        area_score = self.area(embeds)
        mdn_inferences = []
        for mdn in self.mdns:
            # Each head returns the (pi, sigma, mu) triple; keep the list flat.
            mdn_inferences.extend(mdn(embeds))
        return area_score, mdn_inferences
def simp_loss(area_score, mdn_inferences, target):
    """Compute the area-classification loss and the averaged MDN loss.

    Arguments:
        area_score: log-softmax class scores, one column per MDN head.
        mdn_inferences: flat list ``[pi_0, sigma_0, mu_0, pi_1, ...]`` as
            returned by ``SIMP.forward``.
        target: 2-D batch whose column 0 is the 1-based area label and whose
            remaining columns are the regression target.

    Returns:
        (area_loss, mdn_loss): the NLL of the area classifier and the mean
        over heads of the masked MDN losses.
    """
    num_mdns = area_score.size(1)
    loss = torch.zeros(num_mdns).to(device)
    for idx in range(num_mdns):
        pi, sigma, mu = mdn_inferences[3 * idx:3 * idx + 3]
        # Labels are 1-based, so head `idx` is trained on label `idx + 1`.
        loss[idx] = mdn_loss_with_mask(pi, sigma, mu, target,
                                       select_label=idx + 1)
    mean_mdn_loss = torch.mean(loss)
    # as_tensor avoids the copy-construct warning that torch.tensor() raises
    # when handed an existing tensor; NLLLoss wants 0-based long class ids.
    area_target = torch.as_tensor(target[:, 0] - 1).long().to(device)
    area_loss = nllloss(area_score, area_target)
    return area_loss, mean_mdn_loss
class MDN(nn.Module):
    """Mixture density network head.

    ``forward`` returns ``(pi, sigma, mu)`` with shapes (BxG, BxGxO, BxGxO),
    where B is the batch size, G the number of Gaussians and O the number of
    dimensions per Gaussian. ``pi`` holds the mixture weights (each row sums
    to one), ``sigma`` the per-dimension standard deviations and ``mu`` the
    per-dimension means.
    """

    def __init__(self, in_features, out_features, num_gaussians):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_gaussians = num_gaussians
        # Mixture weights are produced in log space and exponentiated in
        # forward().
        mixture_logits = nn.Linear(in_features, num_gaussians)
        self.pi = nn.Sequential(mixture_logits, nn.LogSoftmax(dim=1))
        flat_width = out_features * num_gaussians
        self.sigma = nn.Linear(in_features, flat_width)
        self.mu = nn.Linear(in_features, flat_width)

    def forward(self, minibatch):
        """Map a (B x in_features) batch to the mixture parameters.

        Raises:
            ValueError: if the predicted means contain NaN.
        """
        component_shape = (-1, self.num_gaussians, self.out_features)

        # (B x G) mixture weights.
        weights = self.pi(minibatch).exp()

        # exp keeps every standard deviation strictly positive.
        std = self.sigma(minibatch).exp().view(*component_shape)

        mean_flat = self.mu(minibatch)
        if torch.isnan(mean_flat).any():
            # Dump the offending input before aborting.
            print('input', minibatch)
            raise ValueError('weight overflow')
        mean = mean_flat.view(*component_shape)

        return weights, std, mean
def gaussian_probability(sigma, mu, data):
    """Evaluate each mixture component's density at `data`.

    Every component is a Gaussian with independent dimensions, so its density
    is the product of the per-dimension univariate densities.

    Arguments:
        sigma (BxGxO): standard deviations. B is the batch size, G the number
            of Gaussians and O the number of dimensions per Gaussian.
        mu (BxGxO): means, same layout as `sigma`.
        data (BxO): one O-dimensional point per batch row; it is broadcast
            across the G components.

    Returns:
        probabilities (BxG): density of each row's point under each of that
        row's components.
    """
    # 1 / sqrt(2 * pi), the univariate Gaussian normalising constant.
    inv_sqrt_2pi = 1.0 / math.sqrt(2*math.pi)
    points = data.unsqueeze(1).expand_as(sigma)
    standardized = (points - mu) / sigma
    per_dim_density = inv_sqrt_2pi * torch.exp(-0.5 * standardized**2) / sigma
    return per_dim_density.prod(dim=2)
def mdn_loss(pi, sigma, mu, target):
    """Mean negative log-likelihood of `target` under the mixture.

    Arguments:
        pi (BxG): mixture weights.
        sigma (BxGxO): component standard deviations.
        mu (BxGxO): component means.
        target (BxO): observed points.

    Returns:
        Scalar tensor: the batch mean of ``-log(sum_g pi_g * N_g(target))``.
    """
    weighted = pi * gaussian_probability(sigma, mu, target)
    likelihood = torch.sum(weighted, dim=1)
    # The density product can underflow to exactly 0, which would make the
    # loss infinite. Use the same 1e-10 floor as mdn_loss_with_mask.
    nll = -torch.log(likelihood + 1e-10)
    return torch.mean(nll)
def mdn_loss_with_mask(pi, sigma, mu, target, select_label):
    """Negative log-likelihood restricted to rows carrying `select_label`.

    Arguments:
        pi (BxG): mixture weights.
        sigma (BxGxO): component standard deviations.
        mu (BxGxO): component means.
        target: 2-D batch whose column 0 is the label and whose remaining
            columns are the regression target.
        select_label: label value selecting the rows that contribute.

    Returns:
        The summed NLL of the selected rows divided by the full batch size
        ``B`` (not by the number of selected rows), or a one-element zero
        tensor when no row carries `select_label`.
    """
    n = pi.size(0)
    # Vectorised replacement for the per-row Python scan of the label column.
    matches = target[:n, 0] == select_label
    selected = torch.nonzero(matches).reshape(-1).to(pi.device)
    if selected.numel() == 0:
        return torch.zeros(1, device=pi.device)

    # pi (BxG), sigma (BxGxO), mu (BxGxO)
    pi_select = torch.index_select(pi, 0, selected)
    sigma_select = torch.index_select(sigma, 0, selected)
    mu_select = torch.index_select(mu, 0, selected)
    target_select = torch.index_select(target, 0, selected)

    # Column 0 is the label; the density is evaluated on the rest.
    prob = gaussian_probability(sigma_select, mu_select, target_select[:, 1:])
    prob = torch.sum(pi_select * prob, dim=1)
    # Small floor keeps log() finite when the likelihood underflows.
    nll = -torch.log(prob + 1e-10)
    return torch.sum(nll) / n
def sample(pi, sigma, mu):
    """Draw one sample per batch row from a mixture of Gaussians.

    For every row a component is chosen by inverse-CDF sampling over `pi`,
    then each output dimension is drawn from that component's normal
    distribution. Randomness comes from NumPy's global generator
    (``np.random``).

    Arguments:
        pi (BxG): mixture weights; each row is expected to sum to 1.
        sigma (BxGxO): component standard deviations.
        mu (BxGxO): component means.

    Returns:
        A (BxO) CPU tensor of samples.
    """
    num_rows, _ = pi.shape
    _, num_components, num_dims = mu.shape

    # Move everything to NumPy once instead of indexing tensors element by
    # element inside the loop.
    pi_np = pi.detach().cpu().numpy()
    sigma_np = sigma.detach().cpu().numpy()
    mu_np = mu.detach().cpu().numpy()

    out = torch.zeros(num_rows, num_dims)
    last_component = num_components - 1
    for row in range(num_rows):
        u = np.random.uniform()  # in [0, 1)
        # The cumulative weights split [0, 1) into one segment per component;
        # pick the first component whose cumulative weight exceeds u.
        cumulative = np.cumsum(pi_np[row])
        k = int(np.searchsorted(cumulative, u, side="right"))
        # Rounding can leave the weights summing to slightly less than 1; if
        # u lands in that gap, use the last component instead of silently
        # leaving the row at zero.
        k = min(k, last_component)
        draw = np.random.normal(mu_np[row, k], sigma_np[row, k])
        out[row] = torch.from_numpy(np.asarray(draw))
    return out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment