Created
June 4, 2020 14:26
-
-
Save mbjoseph/86b621d5c286527c76d1ea8519ff67c7 to your computer and use it in GitHub Desktop.
Zero one inflated beta distribution in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch.distributions import Beta\n", | |
"from torch.distributions import constraints\n", | |
"from torch.distributions.exp_family import ExponentialFamily\n", | |
"from torch.distributions.utils import broadcast_all\n", | |
"from torch.distributions.dirichlet import Dirichlet\n", | |
"from numbers import Number" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ZOIBeta(ExponentialFamily):\n", | |
" \"\"\" Zero one inflated Beta distribution\n", | |
" \n", | |
" Args: \n", | |
" p (float or Tensor): Pr(y = 0)\n", | |
" q (float or Tensor): Pr(y = 1 | y != 0)\n", | |
" concentration1 (float or Tensor): 1st Beta dist. parameter \n", | |
" (often referred to as alpha)\n", | |
" concentration0 (float or Tensor): 2nd Beta dist. parameter\n", | |
" (often referred to as beta)\n", | |
" \"\"\"\n", | |
" \n", | |
" arg_constraints = {\n", | |
" 'p': constraints.unit_interval, \n", | |
" 'q': constraints.unit_interval, \n", | |
" 'concentration1': constraints.positive, \n", | |
" 'concentration0': constraints.positive\n", | |
" }\n", | |
" support = constraints.unit_interval # does this include 0 and 1?\n", | |
" has_rsample = False\n", | |
" \n", | |
" def __init__(self, p, q, concentration1, concentration0, validate_args=None):\n", | |
" if isinstance(concentration1, Number) and isinstance(concentration0, Number):\n", | |
" concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])\n", | |
" else:\n", | |
" concentration1, concentration0 = broadcast_all(concentration1, concentration0)\n", | |
" concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)\n", | |
" self._dirichlet = Dirichlet(concentration1_concentration0)\n", | |
" self.log_p = torch.log(p)\n", | |
" self.log1m_p = torch.log(1 - p)\n", | |
" self.log_q = torch.log(q)\n", | |
" self.log1m_q = torch.log(1 - q)\n", | |
" super(ZOIBeta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)\n", | |
" \n", | |
" def beta_lp(self, value):\n", | |
" if self._validate_args:\n", | |
" self._validate_sample(value)\n", | |
" heads_tails = torch.stack([value, 1.0 - value], -1)\n", | |
" return self._dirichlet.log_prob(heads_tails)\n", | |
" \n", | |
" def log_prob(self, value):\n", | |
" lp = torch.zeros_like(value, dtype = torch.float)\n", | |
" if any (0. < value < 1.): \n", | |
" beta_idx = torch.where(0. < value < 1.)\n", | |
" lp[beta_idx] = self.log1m_p + self.log1m_q + self.beta_lp(value[beta_idx])\n", | |
" lp[torch.where(value == 0.)] = self.log_p\n", | |
" lp[torch.where(value == 1.)] = self.log1m_p + self.log_q\n", | |
" return lp\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"zoib = ZOIBeta(p=torch.tensor(.5), \n", | |
" q=torch.tensor(.3), \n", | |
" concentration1=torch.tensor(1.), \n", | |
" concentration0=torch.tensor(1.))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Check that the log probabilities returned are what we expect\n", | |
"\n", | |
"For zeros:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.6931]])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"zoib.log_prob(torch.tensor(0.).view(1, 1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(-0.6931)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"zoib.log_p" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"For ones:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-1.8971]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"zoib.log_prob(torch.tensor(1.).view(1, 1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(-1.8971)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"zoib.log1m_p + zoib.log_q" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"For proportions (if concentration1 and concentration0 are both 1, then we have a uniform Beta prior and we should get the same log probability for all values between 0 and 1:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-1.0498]])" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"zoib.log_prob(torch.tensor(.4).view(1, 1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-1.0498]])" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"zoib.log_prob(torch.tensor(.9).view(1, 1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-1.0498]])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"zoib.log_prob(torch.tensor(.2).view(1, 1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment