Two ansatz modules for PyTorch PDEs.
import torch as tch  # the gist uses `tch` as its torch alias throughout.


class Neural_Density(tch.nn.Module):
    """
    A neural model of a time-dependent probability density on
    a vector-valued state space,
    i.e.: rho(t, {x_0, x_1, ..., x_{state_dim}}).
    Normalization is not enforced here; it could be, e.g. with a
    Gaussian mixture (see Gaussian_Mixture_Density below).
    """
    def __init__(self, state_dim, hidden_dim=64):
        super(Neural_Density, self).__init__()
        self.input_dim = state_dim + 1  # state plus time.
        self.state_dim = state_dim
        self.net = tch.nn.Sequential(
            tch.nn.Linear(self.input_dim, hidden_dim),
            tch.nn.Softplus(),
            tch.nn.Linear(hidden_dim, 1),
            tch.nn.Softplus(),  # density is positive.
        )

    def forward(self, t, x):
        # Evaluate the (unnormalized) density at time t and state x.
        return self.net(tch.cat([t.unsqueeze(-1), x], -1)).squeeze()
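
# A minimal usage sketch (the shapes below are illustrative assumptions,
# not part of the original gist):
#   rho = Neural_Density(state_dim=2)
#   t = tch.rand(8)      # a batch of 8 times
#   x = tch.randn(8, 2)  # the matching batch of states
#   p = rho(t, x)        # unnormalized densities, shape (8,)
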
class Reshape(tch.nn.Module):
    """A small helper module wrapping Tensor.view, for use in Sequential."""
    def __init__(self, shp):
        super(Reshape, self).__init__()
        self.shape = shp

    def forward(self, x):
        return x.view(self.shape)
class Gaussian_Mixture_Density(tch.nn.Module):
    def __init__(self, state_dim,
                 m_dim=1,
                 hidden_dim=16,
                 ):
        """
        A network which parametrically produces a Gaussian mixture
        over the state space, with feed-forward nets that
        parameterize the mixture.
        """
        super(Gaussian_Mixture_Density, self).__init__()
        # rho(t, x) is a density over states parameterized by t,
        # so the networks below take only t (input_dim = 1) as input.
        input_dim = 1
        output_dim = state_dim
        self.output_dim = output_dim
        self.m_dim = m_dim
        mixture_dim = output_dim * m_dim
        # Number of strictly-lower-triangular (correlation) entries.
        self.n_corr = int(self.output_dim * (self.output_dim - 1) / 2.)
        self.sftpls = tch.nn.Softplus()
        self.sftmx = tch.nn.Softmax(dim=-1)
        # Off-diagonal entries of the Cholesky factor.
        self.corr_net = tch.nn.Sequential(
            # tch.nn.Dropout(0.1),
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.Tanh(),
            tch.nn.Linear(hidden_dim, self.n_corr * m_dim),
            Reshape((-1, m_dim, self.n_corr))
        )
        # Diagonal (std.) entries; Softplus keeps them positive.
        self.std_net = tch.nn.Sequential(
            # tch.nn.Dropout(0.1),
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.SELU(),
            tch.nn.Linear(hidden_dim, mixture_dim),
            tch.nn.Softplus(beta=10.),
            Reshape((-1, m_dim, self.output_dim))
        )
        # Component means.
        self.mu_net = tch.nn.Sequential(
            # tch.nn.Dropout(0.1),
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.Tanh(),
            tch.nn.Linear(hidden_dim, mixture_dim),
            Reshape((-1, m_dim, self.output_dim))
        )
        # Mixture weights; Softmax makes them sum to one.
        self.pi_net = tch.nn.Sequential(
            # tch.nn.Dropout(0.1),
            tch.nn.Linear(input_dim, hidden_dim),
            tch.nn.SELU(),
            tch.nn.Linear(hidden_dim, m_dim),
            tch.nn.Tanh(),
            tch.nn.Softmax(dim=-1)
        )
        # Note: attribute assignment already registers these submodules,
        # so explicit add_module calls are unnecessary.
    def pi(self, x):
        return self.pi_net(x)

    def mu(self, x):
        return self.mu_net(x)

    def L(self, x):
        """
        Constructs the lower-triangular Cholesky factor of the covariance
        matrix for each batch element and mixture component.
        """
        batch_size = x.shape[0]
        L = tch.zeros(batch_size, self.m_dim, self.output_dim, self.output_dim,
                      device=x.device, dtype=x.dtype)
        # Scatter the positive std_net outputs onto the diagonal.
        b_inds = tch.arange(batch_size).unsqueeze(1).unsqueeze(1).repeat(1, self.m_dim, self.output_dim).flatten()
        m_inds = tch.arange(self.m_dim).unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, self.output_dim).flatten()
        s_inds = tch.arange(self.output_dim).unsqueeze(0).unsqueeze(0).repeat(batch_size, self.m_dim, 1).flatten()
        L[b_inds, m_inds, s_inds, s_inds] = self.std_net(x).flatten()
        if self.output_dim > 1:
            # Scatter the corr_net outputs onto the strict lower triangle.
            t_inds = tch.tril_indices(self.output_dim, self.output_dim, -1)
            txs = t_inds[0].flatten()
            tys = t_inds[1].flatten()
            bb_inds = tch.arange(batch_size).unsqueeze(1).unsqueeze(1).repeat(1, self.m_dim, txs.shape[0]).flatten()
            mt_inds = tch.arange(self.m_dim).unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, txs.shape[0]).flatten()
            xt_inds = txs.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.m_dim, 1).flatten()
            yt_inds = tys.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.m_dim, 1).flatten()
            L[bb_inds, mt_inds, xt_inds, yt_inds] = self.corr_net(x).flatten()
        return L
    def get_distribution(self, x):
        pi_distribution = tch.distributions.Categorical(self.pi(x))
        GMM = tch.distributions.mixture_same_family.MixtureSameFamily(
            pi_distribution,
            tch.distributions.MultivariateNormal(self.mu(x),
                                                 scale_tril=self.L(x)))
        return GMM

    def forward(self, t, x):
        # Evaluate the mixture density at time t and state x.
        return self.get_distribution(t.unsqueeze(-1)).log_prob(x).exp()
    def rsample(self, t, sample_shape=128):
        """
        Returns samples from the Gaussian mixture (the sample dimension
        is appended last), i.e.: batch x dim x samp.
        Note: despite the name, MixtureSameFamily only supports sample()
        (not a reparameterized rsample), so gradients do not flow
        through these samples.
        """
        # unsqueeze(-1) for consistency with forward(), mean(), and std().
        samps_ = self.get_distribution(t.unsqueeze(-1)).sample(sample_shape=[sample_shape])
        samps = samps_.permute(1, 2, 0)
        return samps

    def mean(self, t):
        return self.get_distribution(t.unsqueeze(-1)).mean

    def std(self, t):
        return tch.sqrt(self.get_distribution(t.unsqueeze(-1)).variance)
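
# A minimal end-to-end sketch. Everything below is illustrative and not
# part of the original gist; the shapes and hyperparameters are assumptions.
if __name__ == "__main__":
    gmm = Gaussian_Mixture_Density(state_dim=2, m_dim=3)
    t = tch.rand(8)      # a batch of 8 times
    x = tch.randn(8, 2)  # the matching batch of states
    p = gmm(t, x)                             # density at (t, x), shape (8,)
    samps = gmm.rsample(t, sample_shape=128)  # shape (8, 2, 128)
    print(p.shape, samps.shape, gmm.mean(t).shape, gmm.std(t).shape)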