Useful Custom PyTorch Distributions
import math

import torch
from torch.distributions import Distribution, Gamma, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import PowerTransform


class InverseGamma(TransformedDistribution):
    """
    https://en.wikipedia.org/wiki/Inverse-gamma_distribution
    Creates an inverse-gamma distribution parameterized by
    `concentration` and `rate`.

        X ~ Gamma(concentration, rate)
        Y = 1/X ~ InverseGamma(concentration, rate)

    :param torch.Tensor concentration: the concentration parameter (i.e. alpha on wikipedia).
    :param torch.Tensor rate: the rate parameter (i.e. beta on wikipedia).
    """
    arg_constraints = {"concentration": constraints.positive, "rate": constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, concentration, rate, validate_args=None):
        base_distribution = Gamma(concentration, rate)
        # Y = X^(-1) with X ~ Gamma is exactly the inverse-gamma, so the
        # power transform gives reparameterized sampling (rsample) for free.
        super().__init__(
            base_distribution,
            PowerTransform(-torch.ones_like(concentration)),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(InverseGamma, _instance)
        return super().expand(batch_shape, _instance=new)

    def entropy(self):
        """
        https://en.wikipedia.org/wiki/Inverse-gamma_distribution
        alpha + log(beta * Gamma(alpha)) - (alpha + 1) * digamma(alpha)
        """
        return (
            self.concentration
            + torch.log(self.rate)
            + torch.lgamma(self.concentration)
            - (1.0 + self.concentration) * torch.digamma(self.concentration)
        )

    @property
    def concentration(self):
        return self.base_dist.concentration

    @property
    def rate(self):
        return self.base_dist.rate
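

# Example usage: a minimal sanity-check sketch (not part of the original gist;
# the helper name `_check_inverse_gamma` and the parameter values are
# illustrative). It cross-checks `log_prob` against the change-of-variables
# formula log p_Y(y) = log p_X(1/y) - 2*log(y) applied to the base Gamma.
def _check_inverse_gamma():
    concentration, rate = torch.tensor(3.0), torch.tensor(2.0)
    inv_gamma = InverseGamma(concentration, rate)
    y = inv_gamma.rsample((5,))  # reparameterized draws, so gradients can flow
    manual = Gamma(concentration, rate).log_prob(1.0 / y) - 2.0 * torch.log(y)
    assert torch.allclose(inv_gamma.log_prob(y), manual)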


class NormalGamma(Distribution):
    """
    https://en.wikipedia.org/wiki/Normal-gamma_distribution
    Creates a normal-gamma distribution parameterized by
    `mu`, `lambd`, `concentration` and `rate`.

        P ~ Gamma(concentration, rate)
        X ~ Normal(mu, variance=(lambd*P)^-1)
        => (X, P) ~ NormalGamma(mu, lambd, concentration, rate)

    :param torch.Tensor mu: the mean parameter for the normal distribution (i.e. mu on wikipedia).
    :param torch.Tensor lambd: the scaling of the precision for the normal distribution (i.e. lambda on wikipedia).
    :param torch.Tensor concentration: the concentration parameter for the gamma distribution (i.e. alpha on wikipedia).
    :param torch.Tensor rate: the rate parameter for the gamma distribution (i.e. beta on wikipedia).
    """
    arg_constraints = {"mu": constraints.real, "lambd": constraints.positive, "concentration": constraints.positive, "rate": constraints.positive}
    # The event is a (mean, precision) pair: the mean component is real and
    # the precision component is positive.
    support = constraints.stack([constraints.real, constraints.positive], dim=-1)
    has_rsample = True

    def __init__(self, mu, lambd, concentration, rate, validate_args=None):
        self._gamma = Gamma(concentration, rate, validate_args=validate_args)
        self._mu = mu
        self._lambd = lambd
        batch_shape = self.mu.size()
        event_shape = torch.Size([2])
        super().__init__(batch_shape, event_shape=event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(NormalGamma, _instance)
        batch_shape = torch.Size(batch_shape)
        # Expand the component distribution and parameters on the private
        # attributes; `lambd` and `mu` are read-only properties.
        new._gamma = self._gamma.expand(batch_shape)
        new._lambd = self.lambd.expand(batch_shape)
        new._mu = self.mu.expand(batch_shape)
        super(NormalGamma, new).__init__(batch_shape, event_shape=self._event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=()):
        # Draw P ~ Gamma, then X | P ~ Normal(mu, std=(lambd*P)^(-1/2)).
        precision = self._gamma.sample(sample_shape)
        mean = Normal(self.mu, (self.lambd * precision).sqrt().reciprocal()).sample()
        return torch.stack([mean, precision], dim=-1)

    def rsample(self, sample_shape=()):
        # Same as sample(), but reparameterized so gradients flow through.
        precision = self._gamma.rsample(sample_shape)
        mean = Normal(self.mu, (self.lambd * precision).sqrt().reciprocal()).rsample()
        return torch.stack([mean, precision], dim=-1)

    def entropy(self):
        """
        https://en.wikipedia.org/wiki/Normal-gamma_distribution
        Joint entropy via the chain rule H(X, P) = H(P) + E_P[H(X | P)],
        using E[log P] = digamma(alpha) - log(beta).
        """
        return (self.concentration + torch.lgamma(self.concentration)
                + (0.5 - self.concentration) * torch.digamma(self.concentration)
                - 0.5 * torch.log(self.rate * self.lambd)
                + 0.5 * (1.0 + math.log(2 * math.pi)))

    @property
    def concentration(self):
        return self._gamma.concentration

    @property
    def rate(self):
        return self._gamma.rate

    @property
    def lambd(self):
        return self._lambd

    @property
    def mu(self):
        return self._mu

    @property
    def mean(self):
        # E[X] = mu and E[P] = concentration / rate.
        return torch.stack([self.mu, self._gamma.mean], dim=-1)

    @property
    def variance(self):
        # Marginal variance of X is rate / (lambd * (concentration - 1)),
        # defined for concentration > 1.
        variance_mean = self.rate / (self.lambd * (self.concentration - 1))
        variance_precision = self._gamma.variance
        return torch.stack([variance_mean, variance_precision], dim=-1)

    def log_prob(self, value):
        # `value[..., 0]` is the mean sample x, `value[..., 1]` the precision tau:
        # log p(x, tau) = alpha*log(beta) + 0.5*log(lambd) + (alpha - 0.5)*log(tau)
        #                 - beta*tau - 0.5*lambd*tau*(x - mu)^2 - lgamma(alpha) - 0.5*log(2*pi)
        value = torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device)
        mean = value[..., 0]
        precision = value[..., 1]
        sq_dist = (mean - self.mu) ** 2
        return (self.concentration * torch.log(self.rate)
                + 0.5 * torch.log(self.lambd)
                + (self.concentration - 0.5) * torch.log(precision)
                - self.rate * precision
                - 0.5 * self.lambd * precision * sq_dist
                - torch.lgamma(self.concentration)
                - 0.5 * math.log(2 * math.pi))
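

# Example usage: a minimal sanity-check sketch (not part of the original gist;
# the helper name `_check_normal_gamma` and the parameter values are
# illustrative). It cross-checks `log_prob` against the factorization
# p(x, tau) = Gamma(tau; alpha, beta) * Normal(x; mu, std=(lambd*tau)^(-1/2)).
def _check_normal_gamma():
    mu, lambd = torch.tensor(0.0), torch.tensor(2.0)
    concentration, rate = torch.tensor(3.0), torch.tensor(1.5)
    normal_gamma = NormalGamma(mu, lambd, concentration, rate)
    value = normal_gamma.rsample((5,))  # shape (5, 2): columns are (mean, precision)
    x, precision = value[..., 0], value[..., 1]
    manual = (Gamma(concentration, rate).log_prob(precision)
              + Normal(mu, (lambd * precision).sqrt().reciprocal()).log_prob(x))
    assert torch.allclose(normal_gamma.log_prob(value), manual)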