Useful Custom PyTorch Distributions

import math

import torch
from torch.distributions import Distribution, Gamma, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import PowerTransform

class InverseGamma(TransformedDistribution):
    """
    https://en.wikipedia.org/wiki/Inverse-gamma_distribution
    Creates an inverse-gamma distribution parameterized by
    `concentration` and `rate`.

        X ~ Gamma(concentration, rate)
        Y = 1/X ~ InverseGamma(concentration, rate)

    :param torch.Tensor concentration: the concentration parameter (i.e. alpha on wikipedia).
    :param torch.Tensor rate: the rate parameter (i.e. beta on wikipedia).
    """

    arg_constraints = {"concentration": constraints.positive, "rate": constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, concentration, rate, validate_args=None):
        base_distribution = Gamma(concentration, rate)
        super().__init__(
            base_distribution,
            PowerTransform(-torch.ones_like(concentration)),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(InverseGamma, _instance)
        return super().expand(batch_shape, _instance=new)

    def entropy(self):
        """
        https://en.wikipedia.org/wiki/Inverse-gamma_distribution
        alpha + log(beta * Gamma(alpha)) - (alpha + 1) * digamma(alpha)
        """
        return (self.concentration + torch.log(self.rate)
                + torch.lgamma(self.concentration)
                - (1.0 + self.concentration) * torch.digamma(self.concentration))

    @property
    def concentration(self):
        return self.base_dist.concentration

    @property
    def rate(self):
        return self.base_dist.rate
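

# A minimal usage sketch (not part of the original gist; the parameter values
# are arbitrary illustrations): draw reparameterized samples and evaluate the
# log density, which picks up the PowerTransform Jacobian term automatically.
if __name__ == "__main__":
    ig = InverseGamma(torch.tensor([3.0]), torch.tensor([2.0]))
    y = ig.rsample((5,))        # positive samples; gradients flow through Gamma.rsample
    print(y, ig.log_prob(y))
    print(ig.entropy())         # closed form from the docstring above
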
class NormalGamma(Distribution):
    """
    https://en.wikipedia.org/wiki/Normal-gamma_distribution
    Creates a normal-gamma distribution parameterized by
    `mu`, `lambd`, `concentration` and `rate`.

        P ~ Gamma(concentration, rate)
        X | P ~ Normal(mu, variance=(lambd * P)^-1)
        => (X, P) ~ NormalGamma(mu, lambd, concentration, rate)

    :param torch.Tensor mu: the mean parameter for the normal distribution (i.e. mu on wikipedia).
    :param torch.Tensor lambd: the scaling of the precision for the normal distribution (i.e. lambda on wikipedia).
    :param torch.Tensor concentration: the concentration parameter for the gamma distribution (i.e. alpha on wikipedia).
    :param torch.Tensor rate: the rate parameter for the gamma distribution (i.e. beta on wikipedia).
    """

    arg_constraints = {
        "mu": constraints.real,
        "lambd": constraints.positive,
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    # Events stack (mean, precision) along the last dim: the mean is real and
    # the precision is positive. Assumes `constraints.stack` (PyTorch >= 1.8).
    support = constraints.stack([constraints.real, constraints.positive], dim=-1)
    has_rsample = True

    def __init__(self, mu, lambd, concentration, rate, validate_args=None):
        self._gamma = Gamma(concentration, rate, validate_args=validate_args)
        self._mu = mu
        self._lambd = lambd
        batch_shape = self.mu.size()
        event_shape = torch.Size([2])
        super().__init__(batch_shape, event_shape=event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(NormalGamma, _instance)
        batch_shape = torch.Size(batch_shape)
        # Expand the underlying Gamma and the parameters onto the new batch shape;
        # assigning through the read-only `lambd` property would raise, so set the
        # private attributes directly.
        new._gamma = self._gamma.expand(batch_shape)
        new._lambd = self.lambd.expand(batch_shape)
        new._mu = self.mu.expand(batch_shape)
        super(NormalGamma, new).__init__(batch_shape, event_shape=self._event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=()):
        precision = self._gamma.sample(sample_shape)
        mu = Normal(self.mu, (self.lambd * precision).sqrt().reciprocal()).sample()
        return torch.stack([mu, precision], dim=-1)

    def rsample(self, sample_shape=()):
        precision = self._gamma.rsample(sample_shape)
        mu = Normal(self.mu, (self.lambd * precision).sqrt().reciprocal()).rsample()
        return torch.stack([mu, precision], dim=-1)

    def entropy(self):
        """
        Joint entropy via H(X, P) = H(P) + E[H(X | P)]:
        alpha + lgamma(alpha) - 0.5 * log(beta) + (0.5 - alpha) * digamma(alpha)
        - 0.5 * log(lambd) + 0.5 * log(2 * pi * e)
        """
        return (self.concentration + torch.lgamma(self.concentration)
                - 0.5 * torch.log(self.rate)
                + (0.5 - self.concentration) * torch.digamma(self.concentration)
                - 0.5 * torch.log(self.lambd)
                + 0.5 * math.log(2 * math.pi * math.e))

    @property
    def concentration(self):
        return self._gamma.concentration

    @property
    def rate(self):
        return self._gamma.rate

    @property
    def lambd(self):
        return self._lambd

    @property
    def mu(self):
        return self._mu

    @property
    def mean(self):
        return torch.stack([self.mu, self._gamma.mean], dim=-1)

    @property
    def variance(self):
        # Marginal variance of the mean component, beta / (lambd * (alpha - 1));
        # finite only for concentration > 1.
        variance_mean = self.rate / (self.lambd * (self.concentration - 1))
        variance_precision = self._gamma.variance
        return torch.stack([variance_mean, variance_precision], dim=-1)

    def log_prob(self, value):
        value = torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device)
        mean = value[..., 0]
        precision = value[..., 1]
        sq_dist = (mean - self.mu) ** 2
        return (self.concentration * torch.log(self.rate)
                + 0.5 * torch.log(self.lambd)
                + (self.concentration - 0.5) * torch.log(precision)
                - self.rate * precision
                - 0.5 * self.lambd * precision * sq_dist
                - torch.lgamma(self.concentration)
                - 0.5 * math.log(2 * math.pi))
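

# A minimal usage sketch (not part of the original gist; shapes and parameter
# values are arbitrary illustrations): events stack (mean, precision) along
# the last dim, so log_prob reduces that dim away.
if __name__ == "__main__":
    ng = NormalGamma(
        mu=torch.zeros(3),
        lambd=torch.ones(3),
        concentration=2.0 * torch.ones(3),
        rate=torch.ones(3),
    )
    xp = ng.rsample((4,))                    # shape (4, 3, 2)
    print(xp.shape, ng.log_prob(xp).shape)   # torch.Size([4, 3, 2]) torch.Size([4, 3])
    print(ng.mean, ng.variance)              # variance of the mean needs concentration > 1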