Created
September 25, 2025 09:30
In this notebook, we explore how to sample directly from the joint of m independent Bernoulli(p_i) truncated to the set where at least one Bernoulli is activated.
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "80a3d06d-bdb7-4ef7-a8a9-b86b9e5b1382",
   "metadata": {},
   "source": [
    "# Truncated joint Bernoulli\n",
    "\n",
    "In this notebook, we explore how to sample from the joint distribution of $m$ independent Bernoulli($p_i$) variables, truncated to the set where at least one Bernoulli is activated.\n",
    "Rejection sampling that discards all-zero samples is a natural idea, but it is extremely inefficient when the 1-probabilities are very small.\n",
    "We thus need a cleverer method that samples from the truncated distribution directly.\n",
    "\n",
    "The key observation is that if at least one Bernoulli is activated, there must exist a *first* activated Bernoulli.\n",
    "We therefore partition all possible samples into $m$ disjoint strata, where in the $i$th stratum the $i$th Bernoulli is the first one activated.\n",
    "The weight of the $i$th stratum is $w_i \\propto p_i \\prod_{j < i} (1 - p_j)$.\n",
    "We sample by: 1) drawing $i \\sim \\operatorname{softmax}(\\log p_i + \\sum_{j < i} \\log (1 - p_j))$; 2) fixing the first $(i-1)$ variables to zero and the $i$th variable to one, and drawing the remaining variables independently according to the law $p_j$ ($j > i$).\n",
    "The probability mass of a sample $\\mathbf{s}$ is $p_0(\\mathbf{s}) / (1 - p_0(\\mathbf{0}))$ if $\\mathbf{s}$ is not all zero, and 0 otherwise,\n",
    "where $p_0$ denotes the probability mass function (pmf) of the original joint of independent Bernoullis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "55261d83-8967-4c3f-835a-4414bdbe5062",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-25T09:26:20.615693Z",
     "iopub.status.busy": "2025-09-25T09:26:20.615478Z",
     "iopub.status.idle": "2025-09-25T09:26:21.406533Z",
     "shell.execute_reply": "2025-09-25T09:26:21.406081Z",
     "shell.execute_reply.started": "2025-09-25T09:26:20.615670Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Draw n samples from a joint of independent Bernoullis, conditioned on not all being tails.\n",
    "class NonZeroBernoulli:\n",
    "    def __init__(self, p):\n",
    "        m = p.size(0)\n",
    "        log_p = torch.log(p)\n",
    "        log1m_p = torch.log1p(-p)\n",
    "        # exclusive prefix sums of log(1 - p_j), plus the total sum\n",
    "        c0_log1m_p, sum_log1m_p = torch.cat((torch.zeros(1), torch.cumsum(log1m_p, 0))).split([m, 1])\n",
    "        self.w = torch.softmax(log_p + c0_log1m_p, 0)  # strata weights for the ith coin being the first head\n",
    "        self.p = p\n",
    "        self.log_p = log_p\n",
    "        self.log1m_p = log1m_p\n",
    "        self.log_z = torch.log(-torch.expm1(sum_log1m_p))  # log normalization constant, log(1 - P(all tails))\n",
    "\n",
    "    def sample(self, n):\n",
    "        m = self.p.size(0)\n",
    "        s = torch.empty(n, m, dtype=torch.bool).bernoulli_(self.p)\n",
    "        i = torch.multinomial(self.w, n, replacement=True)  # index of the first head per sample\n",
    "        s.masked_fill_(torch.arange(m).repeat(n, 1) < i.unsqueeze(1), 0)  # zero out everything before it\n",
    "        s[torch.arange(n), i] = 1\n",
    "        return s\n",
    "\n",
    "    def log_prob(self, s):\n",
    "        log_p0 = torch.where(s.bool(), self.log_p, self.log1m_p).sum(-1)\n",
    "        return (log_p0 - self.log_z).masked_fill(~s.any(-1), float('-inf'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84b662c8-ea8b-4fef-bf0b-a3124bc37b89",
   "metadata": {},
   "source": [
    "Construct a joint of Bernoullis whose all-zero probability dominates, so that naive rejection sampling would almost never terminate.\n",
    "Then draw samples from our truncated distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e2f55618-0073-4801-92b5-8a35952daf3f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-25T09:26:21.406898Z",
     "iopub.status.busy": "2025-09-25T09:26:21.406768Z",
     "iopub.status.idle": "2025-09-25T09:26:21.431259Z",
     "shell.execute_reply": "2025-09-25T09:26:21.430866Z",
     "shell.execute_reply.started": "2025-09-25T09:26:21.406889Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 0],\n",
       "        [0, 1],\n",
       "        [1, 0],\n",
       "        [1, 0],\n",
       "        [0, 1],\n",
       "        [1, 0],\n",
       "        [0, 1],\n",
       "        [0, 1],\n",
       "        [0, 1],\n",
       "        [0, 1]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p = torch.tensor([1e-8, 2e-8])\n",
    "nzb = NonZeroBernoulli(p)\n",
    "nzb.sample(10).long()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc60ab28-8a6e-47dd-b64b-296894bc3ec9",
   "metadata": {},
   "source": [
    "Inspect the probability masses for all cases:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "24ff062d-b105-4b57-9aa3-dda200ffea5d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-25T09:26:21.431620Z",
     "iopub.status.busy": "2025-09-25T09:26:21.431520Z",
     "iopub.status.idle": "2025-09-25T09:26:21.434541Z",
     "shell.execute_reply": "2025-09-25T09:26:21.434259Z",
     "shell.execute_reply.started": "2025-09-25T09:26:21.431610Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000e+00, 6.6667e-01, 3.3333e-01, 6.6667e-09])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nzb.log_prob(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])).exp()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6876239b-ec12-477d-ae1d-a3bdc7df03d8",
   "metadata": {},
   "source": [
    "Ensure the masses sum to one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3c2f25ac-a173-4f06-898e-9272d478139d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-25T09:26:21.434867Z",
     "iopub.status.busy": "2025-09-25T09:26:21.434773Z",
     "iopub.status.idle": "2025-09-25T09:26:21.437424Z",
     "shell.execute_reply": "2025-09-25T09:26:21.437164Z",
     "shell.execute_reply.started": "2025-09-25T09:26:21.434857Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.0000)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nzb.log_prob(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])).exp().sum(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e9d0b19-74bf-412f-9ee5-669dca762dda",
   "metadata": {},
   "source": [
    "**Remark**.\n",
    "The same idea extends to other truncations: a) *at least $k$* variables activated, by applying the stratification recursively (weighting each stratum by the probability that the remaining variables supply the rest of the activations); b) *at most $k$* variables activated, by reducing to at least $m - k$ variables being *not* activated; etc."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
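Outside the notebook, the stratified construction can be sketched in pure Python without PyTorch; this is a minimal illustration, not part of the gist, and `sample_nonzero` is a hypothetical name. The sample frequencies can be checked against the closed-form truncated masses $p_0(\mathbf{s}) / (1 - p_0(\mathbf{0}))$.

```python
import random

def sample_nonzero(p, rng):
    """Draw one sample with at least one activation, stratifying on
    the index of the first activated variable."""
    m = len(p)
    # stratum weights: w_i proportional to p_i * prod_{j<i} (1 - p_j)
    w, prefix = [], 1.0
    for pi in p:
        w.append(prefix * pi)
        prefix *= 1.0 - pi
    i = rng.choices(range(m), weights=w)[0]
    # zeros before i, a one at i, free Bernoulli(p_j) draws after i
    return [0] * i + [1] + [int(rng.random() < pj) for pj in p[i + 1:]]

rng = random.Random(0)
print(sample_nonzero([0.3, 0.5], rng))  # e.g. [1, 0] or [0, 1], never [0, 0]
```

Note that `rng.choices` normalizes the weights internally, playing the role of the softmax over log-weights in the notebook.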
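Extension (a) of the remark (at least $k$ activations) can be sketched as follows, under the assumption that each stratum weight is corrected by the probability that the remaining variables can still supply the missing activations, computed with a Poisson-binomial dynamic program. The helper names `p_at_least` and `sample_at_least_k` are hypothetical and not from the gist.

```python
import random

def p_at_least(p, k):
    """P(at least k of independent Bernoulli(p) are active), via a
    Poisson-binomial DP that lumps counts >= k into one state."""
    dp = [1.0] + [0.0] * k  # dp[c] = P(min(count so far, k) == c)
    for pi in p:
        new = [0.0] * (k + 1)
        for c in range(k + 1):
            new[c] += dp[c] * (1.0 - pi)
            new[min(c + 1, k)] += dp[c] * pi
        dp = new
    return dp[k]

def sample_at_least_k(p, k, rng):
    """Sample conditioned on >= k activations by recursively stratifying
    on the position of the next activated variable."""
    m = len(p)
    s = [0] * m
    i = 0
    while k > 0:
        # stratum weight of j: prod_{i<=l<j} (1-p_l) * p_j * P(>= k-1 active in p[j+1:])
        w, prefix = [], 1.0
        for j in range(i, m):
            w.append(prefix * p[j] * p_at_least(p[j + 1:], k - 1))
            prefix *= 1.0 - p[j]
        j = i + rng.choices(range(m - i), weights=w)[0]
        s[j] = 1  # j is the next activated position; everything before it stays 0
        i, k = j + 1, k - 1
    for l in range(i, m):  # the remaining variables are unconstrained
        s[l] = int(rng.random() < p[l])
    return s

rng = random.Random(1)
print(sample_at_least_k([0.3, 0.5, 0.2], 2, rng))  # always has >= 2 ones
```

With $k = 1$ the tail correction `p_at_least(..., 0)` is identically 1, and the scheme reduces to the notebook's sampler.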