@kkew3
Created September 25, 2025 09:30
{
"cells": [
{
"cell_type": "markdown",
"id": "80a3d06d-bdb7-4ef7-a8a9-b86b9e5b1382",
"metadata": {},
"source": [
"# Truncated joint Bernoulli\n",
"\n",
"In this notebook, we explore how to sample from the joint of $m$ independent Bernoulli($p_i$) truncated to the set where at least one Bernoulli is activated.\n",
"Rejection sampling where all-0 samples are rejected is a natural idea, but it's extremely inefficient when the 1-probabilities are very small.\n",
"We therefore need a cleverer method that samples from the truncated distribution directly.\n",
"\n",
"One key observation is that if at least one Bernoulli is activated, then there must exist a *first* activated Bernoulli.\n",
"Thus, we partition all possible samples into $m$ disjoint strata, where in the $i$th stratum the $i$th Bernoulli is the first one activated.\n",
"The weight of the $i$th stratum is $w_i \\propto p_i \\prod_{j < i} (1 - p_j)$.\n",
"We sample by: 1) drawing $i \\sim \\operatorname{softmax}(\\log p_i + \\sum_{j < i} \\log (1 - p_j))$; 2) fixing the first $(i-1)$ variables to zero and the $i$th variable to one, and drawing the remaining variables $j > i$ independently according to their laws $p_j$.\n",
"The probability mass of a sample $\\mathbf{s}$ is given by $p_0(\\mathbf{s}) / (1-p_0(\\mathbf{0}))$ if $\\mathbf{s}$ is not all zero, and 0 otherwise.\n",
"Here, $p_0$ denotes the probability mass function (pmf) of the original joint independent Bernoullis."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "55261d83-8967-4c3f-835a-4414bdbe5062",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-25T09:26:20.615693Z",
"iopub.status.busy": "2025-09-25T09:26:20.615478Z",
"iopub.status.idle": "2025-09-25T09:26:21.406533Z",
"shell.execute_reply": "2025-09-25T09:26:21.406081Z",
"shell.execute_reply.started": "2025-09-25T09:26:20.615670Z"
}
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Draw n samples from a joint of independent Bernoullis, conditioned on not all being tails.\n",
"class NonZeroBernoulli:\n",
"    def __init__(self, p):\n",
"        m = p.size(0)\n",
"        log_p = torch.log(p)\n",
"        log1m_p = torch.log1p(-p)\n",
"        # Exclusive cumulative sums of log(1 - p_j), plus the total sum.\n",
"        c0_log1m_p, sum_log1m_p = torch.cat((torch.zeros(1), torch.cumsum(log1m_p, 0))).split([m, 1])\n",
"        self.w = torch.softmax(log_p + c0_log1m_p, 0)  # strata weights: the ith coin is the first head\n",
"        self.p = p\n",
"        self.log_p = log_p\n",
"        self.log1m_p = log1m_p\n",
"        self.log_z = torch.log(-torch.expm1(sum_log1m_p))  # log normalization constant, log(1 - p_0(0))\n",
"\n",
"    def sample(self, n):\n",
"        m = self.p.size(0)\n",
"        s = torch.empty(n, m, dtype=torch.bool).bernoulli_(self.p)  # unconstrained draws\n",
"        i = torch.multinomial(self.w, n, replacement=True)  # stratum index of each sample\n",
"        s.masked_fill_(torch.arange(m).repeat(n, 1) < i.unsqueeze(1), 0)  # coins before the ith are tails\n",
"        s[torch.arange(n), i] = 1  # the ith coin is the first head\n",
"        return s\n",
"\n",
"    def log_prob(self, s):\n",
"        log_p0 = torch.where(s.bool(), self.log_p, self.log1m_p).sum(-1)  # log pmf under the untruncated joint\n",
"        return (log_p0 - self.log_z).masked_fill(~s.any(-1), float('-inf'))"
]
},
{
"cell_type": "markdown",
"id": "84b662c8-ea8b-4fef-bf0b-a3124bc37b89",
"metadata": {},
"source": [
"Construct a joint Bernoulli where the all-0 probability dominates.\n",
"Then draw samples from our truncated distribution."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e2f55618-0073-4801-92b5-8a35952daf3f",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-25T09:26:21.406898Z",
"iopub.status.busy": "2025-09-25T09:26:21.406768Z",
"iopub.status.idle": "2025-09-25T09:26:21.431259Z",
"shell.execute_reply": "2025-09-25T09:26:21.430866Z",
"shell.execute_reply.started": "2025-09-25T09:26:21.406889Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 0],\n",
" [0, 1],\n",
" [1, 0],\n",
" [1, 0],\n",
" [0, 1],\n",
" [1, 0],\n",
" [0, 1],\n",
" [0, 1],\n",
" [0, 1],\n",
" [0, 1]])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"p = torch.tensor([1e-8, 2e-8])\n",
"nzb = NonZeroBernoulli(p)\n",
"nzb.sample(10).long()"
]
},
{
"cell_type": "markdown",
"id": "cc60ab28-8a6e-47dd-b64b-296894bc3ec9",
"metadata": {},
"source": [
"Inspect the probability masses for all cases:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "24ff062d-b105-4b57-9aa3-dda200ffea5d",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-25T09:26:21.431620Z",
"iopub.status.busy": "2025-09-25T09:26:21.431520Z",
"iopub.status.idle": "2025-09-25T09:26:21.434541Z",
"shell.execute_reply": "2025-09-25T09:26:21.434259Z",
"shell.execute_reply.started": "2025-09-25T09:26:21.431610Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0000e+00, 6.6667e-01, 3.3333e-01, 6.6667e-09])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nzb.log_prob(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])).exp()"
]
},
{
"cell_type": "markdown",
"id": "6876239b-ec12-477d-ae1d-a3bdc7df03d8",
"metadata": {},
"source": [
"Verify that the masses sum to one:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3c2f25ac-a173-4f06-898e-9272d478139d",
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-25T09:26:21.434867Z",
"iopub.status.busy": "2025-09-25T09:26:21.434773Z",
"iopub.status.idle": "2025-09-25T09:26:21.437424Z",
"shell.execute_reply": "2025-09-25T09:26:21.437164Z",
"shell.execute_reply.started": "2025-09-25T09:26:21.434857Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.0000)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nzb.log_prob(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])).exp().sum(0)"
]
},
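{
"cell_type": "markdown",
"id": "a1b2c3d4-1111-4e5f-9a0b-000000000001",
"metadata": {},
"source": [
"As an extra sanity check (not part of the original derivation), we can compare empirical sample frequencies against the analytical pmf; with moderate $p_i$ and many samples the two should agree closely:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d4-1111-4e5f-9a0b-000000000002",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: empirical frequencies vs. analytical probabilities.\n",
"torch.manual_seed(0)\n",
"nzb2 = NonZeroBernoulli(torch.tensor([0.3, 0.5]))\n",
"s = nzb2.sample(100000)\n",
"support = torch.tensor([[0, 1], [1, 0], [1, 1]])\n",
"emp = torch.stack([(s == row.bool()).all(-1).float().mean() for row in support])\n",
"torch.stack((emp, nzb2.log_prob(support).exp()))  # the two rows should be close"
]
},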
{
"cell_type": "markdown",
"id": "0e9d0b19-74bf-412f-9ee5-669dca762dda",
"metadata": {},
"source": [
"**Remark**.\n",
"The same idea extends to related truncations: a) *at least $k$ variables activated*, by applying the stratification recursively; b) *at most $k$ variables activated*, by reducing to the case of at least $m - k$ variables being *not* activated (i.e., swapping the roles of 0 and 1 and replacing $p_i$ with $1 - p_i$); etc."
]
}
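,
{
"cell_type": "markdown",
"id": "a1b2c3d4-1111-4e5f-9a0b-000000000003",
"metadata": {},
"source": [
"As an illustrative sketch of case (b) above (not in the original notebook), take $k = m - 1$: 'at most $m-1$ activated' is exactly 'at least one *not* activated', so we can reuse `NonZeroBernoulli` on the flipped probabilities $1 - p_i$ and flip the samples back:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b2c3d4-1111-4e5f-9a0b-000000000004",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: sample truncated to 'not all ones', i.e. at most m-1 activated.\n",
"# Flip: y = 1 - x with y ~ Bernoulli(1 - p), truncated to at least one y activated.\n",
"p = torch.tensor([0.9, 0.95])\n",
"flipped = NonZeroBernoulli(1 - p)\n",
"x = ~flipped.sample(10)  # flip back; no row is all ones\n",
"x.long()"
]
}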
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}