@pmineiro
Last active August 3, 2022 19:20
An off-policy confidence sequence suitable for general purposes which supports oblivious data censorship.
{
"cells": [
{
"cell_type": "markdown",
"id": "38cb8dc5",
"metadata": {},
"source": [
"# Reference OPE CS Impl"
]
},
{
"cell_type": "markdown",
"id": "376afba0",
"metadata": {},
"source": [
"An off-policy confidence sequence suitable for general purposes which supports oblivious data censorship."
]
},
{
"cell_type": "markdown",
"id": "ed4a91ca",
"metadata": {},
"source": [
"## API\n",
"\n",
"The public API is:\n",
" * Constructor `EmpBernDynDropCS`$(r_\\min, r_\\max, \\text{adjust})$: takes an initial reward range and a boolean saying whether to adjust the range automatically.\n",
" * `addobs(w_t, r_t, p_drop=0, n_drop=None)`: observe an importance-weighted reward.\n",
"   * `p_drop` and `n_drop` model an oblivious random data censorship process, e.g., a logging system randomly dropping events in response to queue backpressure.\n",
"     * Oblivious means “conditionally independent of the reward given the context and action”: ${p_{\\text{drop}}}_t \\perp r_{a_t} | a_t, x_t$.\n",
"   * `p_drop` is the probability of dropping an event and should be $\\in [0, 1)$.\n",
"   * `n_drop` is the number of events that were dropped prior to this event being not-dropped. It should be `None` or non-negative.\n",
"     * Since `addobs()` is being called, this event is presumed not-dropped.\n",
"     * If `n_drop is None`, we assume `n_drop` is the mean of a negative binomial process, i.e., `p_drop / (1 - p_drop)`.\n",
"       * If you know `n_drop`, it's better to use the actual value than this assumption.\n",
" * `getci`$(\\alpha)$: return a CI at confidence level $(1 - \\alpha)$."
]
},
{
"cell_type": "markdown",
"id": "29e5928b",
"metadata": {},
"source": [
"## Coverage Guarantee"
]
},
{
"cell_type": "markdown",
"id": "85ec39a4",
"metadata": {},
"source": [
"The coverage guarantee holds only if you can certify these preconditions:\n",
" * $w_t = \\frac{d\\pi_t}{d \\mu_t}(a_t)$ is the correct importance weight, where\n",
"   * $\\mu_t$ is the logging policy (i.e., the distribution from which the action is drawn);\n",
"   * $\\pi_t$ is the policy being evaluated (i.e., the policy whose mean is being estimated).\n",
" * $r_t \\in [ r_{\\min}, r_{\\max} ]$ with probability 1.\n",
" * $r_t = r_t(a_t)$ where $a_t \\sim \\mu_t$.\n",
" * `p_drop` is the actual probability of dropping this event (which, by virtue of being here, is not-dropped).\n",
" * `n_drop` is the actual number of dropped events prior to this event being not-dropped.\n",
"\n",
"Then you get the following guarantee $$\n",
"\\mathrm{Pr}\\left( \\forall t: \\frac{1}{t} \\sum_{s=1}^t \\mathbb{E}_{\\substack{s-1 \\\\ a \\sim \\pi_s}}\\left[r_s(a)\\right] \\in \\text{getci}(\\alpha) \\right) \\geq 1 - \\alpha.\n",
"$$\n",
"This guarantee:\n",
" * is time-uniform (the coverage is simultaneously valid for all sample sizes);\n",
" * covers the running mean of the policy sequence (evaluated in an environment where `p_drop=0` and `n_drop=0` always);\n",
" * holds when the environment [adaptively](https://math.stackexchange.com/questions/1794875/what-is-the-difference-between-an-adapted-process-and-a-predictable-process) changes each timestep;\n",
"   * this includes `p_drop` adaptively changing each timestep;\n",
" * holds when the policy being evaluated [predictably](https://math.stackexchange.com/questions/1794875/what-is-the-difference-between-an-adapted-process-and-a-predictable-process) changes with each timestep."
]
},
{
"cell_type": "markdown",
"id": "6da1e3b7",
"metadata": {},
"source": [
"### Reward Range Robustness"
]
},
{
"cell_type": "markdown",
"id": "7d3e116e",
"metadata": {},
"source": [
"To accommodate unknown reward ranges there are two strategies:\n",
" * `adjust=False`: clips the realized reward to the constructor-supplied range $[r_{\\min}, r_{\\max}]$ and provides the coverage guarantee on this modified random variable.\n",
" * `adjust=True`: expands the reward range whenever an observed value falls outside the constructor-supplied range.\n",
"   * In this case, the coverage guarantee is conditioned on observing the complete range (if it was initially specified incorrectly), so the interval can cover a value very different from the running mean if extreme reward values are rare."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "058f455c",
"metadata": {
"code_folding": [
0,
31,
110,
130,
148
]
},
"outputs": [],
"source": [
"class IncrementalFsum:\n",
"    \"\"\" Incremental exact summation via non-overlapping partials (the approach behind math.fsum); cf. https://en.wikipedia.org/wiki/Kahan_summation_algorithm \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        self.partials = []\n",
"\n",
"    def __iadd__(self, x):\n",
"        i = 0\n",
"        for y in self.partials:\n",
"            if abs(x) < abs(y):\n",
"                x, y = y, x\n",
"            hi = x + y\n",
"            lo = y - (hi - x)  # rounding error of hi = x + y\n",
"            if lo:\n",
"                self.partials[i] = lo\n",
"                i += 1\n",
"            x = hi\n",
"        self.partials[i:] = [x]\n",
"        return self\n",
"\n",
"    def __add__(self, other):\n",
"        result = IncrementalFsum()\n",
"        result.partials = list(self.partials)  # floats are immutable, so a shallow copy suffices\n",
"        for y in other.partials:\n",
"            result += y\n",
"        return result\n",
"\n",
"    def __float__(self):\n",
"        return sum(self.partials, 0.0)\n",
"\n",
"class EmpBernDynDropCS(object):\n",
"    def __init__(self, rmin=0, rmax=1, adjust=True):\n",
"        super().__init__()\n",
"\n",
"        assert rmin <= rmax, (rmin, rmax)\n",
"\n",
"        self.rho = 1\n",
"        self.rmin = rmin\n",
"        self.rmax = rmax\n",
"        self.adjust = adjust\n",
"\n",
"        self.t = 0\n",
"\n",
"        self.sumwsqrsq = IncrementalFsum()\n",
"        self.sumwsqr = IncrementalFsum()\n",
"        self.sumwsq = IncrementalFsum()\n",
"        self.sumwr = IncrementalFsum()\n",
"        self.sumw = IncrementalFsum()\n",
"        self.sumwrxhatlow = IncrementalFsum()\n",
"        self.sumwxhatlow = IncrementalFsum()\n",
"        self.sumxhatlowsq = IncrementalFsum()\n",
"        self.sumwrxhathigh = IncrementalFsum()\n",
"        self.sumwxhathigh = IncrementalFsum()\n",
"        self.sumxhathighsq = IncrementalFsum()\n",
"\n",
"    def addobs(self, w, r, p_drop=0, n_drop=None):\n",
"        assert w >= 0\n",
"        assert 0 <= p_drop < 1\n",
"        assert n_drop is None or n_drop >= 0\n",
"\n",
"        if not self.adjust:\n",
"            r = min(self.rmax, max(self.rmin, r))\n",
"        else:\n",
"            self.rmin = min(self.rmin, r)\n",
"            self.rmax = max(self.rmax, r)\n",
"\n",
"        if n_drop is None:\n",
"            n_drop = p_drop / (1 - p_drop)\n",
"\n",
"        if n_drop > 0:\n",
"            import scipy.special as sc\n",
"\n",
"            # we have to simulate presenting n_drop events with w=0 in a row, which we can do in closed form\n",
"            # Sum[(a/(b + s))^2, { s, 0, n - 1 }]\n",
"            # a^2 PolyGamma[1,b]-a^2 PolyGamma[1,b+n]\n",
"\n",
"            sumXlow = (float(self.sumwr) - float(self.sumw) * self.rmin) / (self.rmax - self.rmin)\n",
"            alow = sumXlow + 1/2\n",
"            blow = self.t + 1\n",
"            self.sumxhatlowsq += alow**2 * (sc.polygamma(1, blow).item() - sc.polygamma(1, blow + n_drop).item())\n",
"\n",
"            sumXhigh = (float(self.sumw) * self.rmax - float(self.sumwr)) / (self.rmax - self.rmin)\n",
"            ahigh = sumXhigh + 1/2\n",
"            bhigh = self.t + 1\n",
"            self.sumxhathighsq += ahigh**2 * (sc.polygamma(1, bhigh).item() - sc.polygamma(1, bhigh + n_drop).item())\n",
"\n",
"            self.t += n_drop\n",
"\n",
"        sumXlow = (float(self.sumwr) - float(self.sumw) * self.rmin) / (self.rmax - self.rmin)\n",
"        Xhatlow = (sumXlow + 1/2) / (self.t + 1)\n",
"        sumXhigh = (float(self.sumw) * self.rmax - float(self.sumwr)) / (self.rmax - self.rmin)\n",
"        Xhathigh = (sumXhigh + 1/2) / (self.t + 1)\n",
"\n",
"        w /= (1 - p_drop)  # reweight the surviving event to compensate for censored ones\n",
"\n",
"        self.sumwsqrsq += (w * r)**2\n",
"        self.sumwsqr += w**2 * r\n",
"        self.sumwsq += w**2\n",
"        self.sumwr += w * r\n",
"        self.sumw += w\n",
"        self.sumwrxhatlow += w * r * Xhatlow\n",
"        self.sumwxhatlow += w * Xhatlow\n",
"        self.sumxhatlowsq += Xhatlow**2\n",
"        self.sumwrxhathigh += w * r * Xhathigh\n",
"        self.sumwxhathigh += w * Xhathigh\n",
"        self.sumxhathighsq += Xhathigh**2\n",
"\n",
"        self.t += 1\n",
"\n",
"    def getci(self, alpha):\n",
"        if self.t == 0 or self.rmin == self.rmax:\n",
"            return self.rmin, self.rmax  # tuple, matching the return type below\n",
"\n",
"        sumvlow = (  (float(self.sumwsqrsq) - 2 * self.rmin * float(self.sumwsqr) + self.rmin**2 * float(self.sumwsq)) / (self.rmax - self.rmin)**2\n",
"                   - 2 * (float(self.sumwrxhatlow) - self.rmin * float(self.sumwxhatlow)) / (self.rmax - self.rmin)\n",
"                   + float(self.sumxhatlowsq)\n",
"                  )\n",
"        sumXlow = (float(self.sumwr) - float(self.sumw) * self.rmin) / (self.rmax - self.rmin)\n",
"        l = self.__lblogwealth(t=self.t, sumXt=sumXlow, v=sumvlow, rho=self.rho, alpha=alpha/2)\n",
"\n",
"        sumvhigh = (  (float(self.sumwsqrsq) - 2 * self.rmax * float(self.sumwsqr) + self.rmax**2 * float(self.sumwsq)) / (self.rmax - self.rmin)**2\n",
"                    + 2 * (float(self.sumwrxhathigh) - self.rmax * float(self.sumwxhathigh)) / (self.rmax - self.rmin)\n",
"                    + float(self.sumxhathighsq)\n",
"                   )\n",
"        sumXhigh = (float(self.sumw) * self.rmax - float(self.sumwr)) / (self.rmax - self.rmin)\n",
"        u = 1 - self.__lblogwealth(t=self.t, sumXt=sumXhigh, v=sumvhigh, rho=self.rho, alpha=alpha/2)\n",
"\n",
"        return self.rmin + l * (self.rmax - self.rmin), self.rmin + u * (self.rmax - self.rmin)\n",
"\n",
"    def __logwealth(self, *, s, v, rho):\n",
"        from math import log\n",
"\n",
"        def loggammalowerinc(*, a, x):\n",
"            import scipy.special as sc\n",
"\n",
"            return log(sc.gammainc(a, x)) + sc.loggamma(a)\n",
"\n",
"        assert s + v + rho > 0\n",
"        assert rho > 0\n",
"\n",
"        return (s + v\n",
"                + rho * log(rho)\n",
"                - (v + rho) * log(s + v + rho)\n",
"                + loggammalowerinc(a = v + rho, x = s + v + rho)\n",
"                - loggammalowerinc(a = rho, x = rho)\n",
"               )\n",
"\n",
"    def __lblogwealth(self, *, t, sumXt, v, rho, alpha):\n",
"        from math import log\n",
"        import scipy.optimize as so\n",
"\n",
"        assert 0 < alpha < 1, alpha\n",
"        thres = -log(alpha)\n",
"\n",
"        minmu = 0\n",
"        logwealthminmu = self.__logwealth(s=sumXt, v=v, rho=rho)\n",
"\n",
"        if logwealthminmu <= thres:\n",
"            return minmu\n",
"\n",
"        maxmu = min(1, sumXt/t)\n",
"        logwealthmaxmu = self.__logwealth(s=sumXt - t * maxmu, v=v, rho=rho)\n",
"\n",
"        if logwealthmaxmu >= thres:\n",
"            return maxmu\n",
"\n",
"        res = so.root_scalar(f = lambda mu: self.__logwealth(s=sumXt - t * mu, v=v, rho=rho) - thres,\n",
"                             method = 'brentq',\n",
"                             bracket = [ minmu, maxmu ])\n",
"        assert res.converged, res\n",
"        return res.root"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "74bb2b73",
"metadata": {
"code_folding": [
1,
6,
48
]
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1152x432 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1152x432 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1152x432 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"class DataGen(object):\n",
"    def __init__(self, *, wmax, expwsq, truemu, seed, rvals):\n",
"        import numpy as np\n",
"        import scipy.optimize as so\n",
"        import random\n",
"\n",
"        if False:\n",
"            # { 0, 1, wmax } \\times { 0, 1 } -> 6 values -> we need 6 constraints\n",
"            # 1 = sum_i p_i\n",
"            # 1 = sum_i w_i p_i\n",
"            # E[w^2] = sum_i w_i^2 p_i\n",
"            # logging policy value = sum_i r_i p_i\n",
"            # evaluated policy value = sum_i w_i r_i p_i\n",
"            # we need 1 more constraint to be unique ...\n",
"            # SURPRISE: just the above 5 constraints can be infeasible ...\n",
"            # instead just minimize the logging policy value subject to other constraints\n",
"            # this makes the distribution very difficult to lower bound\n",
"            pass\n",
"\n",
"        self.gen = random.Random(seed)\n",
"        self.wmax = wmax\n",
"        self.expwsq = expwsq\n",
"        self.truemu = truemu\n",
"        self.population = [ (w, r) for w in (0, 1, wmax,) for r in rvals ]\n",
"\n",
"        c = [ r for (w, r) in self.population ]\n",
"        A_eq = [\n",
"            [ 1 for (w, r) in self.population ],\n",
"            [ w for (w, r) in self.population ],\n",
"            [ w**2 for (w, r) in self.population ],\n",
"            [ w*r for (w, r) in self.population ],\n",
"        ]\n",
"        b_eq = [ 1, 1, expwsq, truemu, ]\n",
"\n",
"        res = so.linprog(np.array(c), A_eq=A_eq, b_eq=b_eq)\n",
"        assert res.success, res\n",
"        self.probs = res.x\n",
"        self.logmu = res.fun\n",
"\n",
"        ewwm1r = self.probs.dot([ w * (w - 1) * r for (w, r) in self.population ])\n",
"        ewm1sq = self.probs.dot([ (w - 1)**2 for (w, r) in self.population])\n",
"        self.kappalowstar = -ewwm1r/ewm1sq if ewm1sq > 0 else 0\n",
"        ewwm11mr = self.probs.dot([ w * (w - 1) * (1 - r) for (w, r) in self.population ])\n",
"        self.kappahighstar = -ewwm11mr/ewm1sq if ewm1sq > 0 else 0\n",
"\n",
"        self._expOp = lambda func: sum(p * func(w) for p, (w, _) in zip(self.probs, self.population))\n",
"        self.clippedtruemu = sum(p * w * r for p, (w, r) in zip(self.probs, [ (w, r) for w in (0, 1, wmax) for r in (0, 1)]))\n",
"\n",
"    def genobs(self):\n",
"        w, r = self.gen.choices(population=self.population,\n",
"                                weights=self.probs,\n",
"                               )[0]\n",
"        return w, r, self._expOp\n",
"\n",
"def megasim(*, T, datagen, wmax, adjust, seed, dt=1, alpha = 0.05):\n",
"    import itertools\n",
"    from matplotlib import pyplot as plt\n",
"    import numpy as np\n",
"\n",
"    cs = EmpBernDynDropCS(adjust=adjust)\n",
"    gen = np.random.RandomState(seed)\n",
"    pdroplow = 9/10\n",
"    pdrophigh = 99/100\n",
"\n",
"    wrz = []\n",
"    lbz, ubz = [], []\n",
"    n_drop = 0\n",
"\n",
"    for t in range(T):\n",
"        w, r, expOp = datagen.genobs()\n",
"        p_drop = gen.uniform(low=pdroplow, high=pdrophigh)\n",
"        should_drop = (gen.uniform(low=0, high=1) <= p_drop)\n",
"        if should_drop:\n",
"            n_drop += 1\n",
"        else:\n",
"            cs.addobs(w, r, p_drop, n_drop)\n",
"            n_drop = 0\n",
"\n",
"        if t % dt == 0:\n",
"            wrz.append(w*r)\n",
"            l, u = cs.getci(alpha=alpha)  # use the alpha argument rather than a hard-coded 0.05\n",
"            lbz.append(l)\n",
"            ubz.append(u)\n",
"\n",
"    fig, ax = plt.subplots(1, 2)\n",
"    fig.set_size_inches(16, 6)\n",
"    ax[0].plot(list(itertools.accumulate(wrz)))\n",
"    ax[0].set_ylabel('sum(wr)')\n",
"    color = 'C0'  # first default-cycle color; avoids the private _get_lines.prop_cycler attribute\n",
"    ax[1].plot(lbz, label='CS', color=color)\n",
"    ax[1].plot(ubz, color=color)\n",
"    color = 'C1'  # second default-cycle color\n",
"    ax[1].plot([datagen.truemu if adjust else datagen.clippedtruemu]*len(lbz), linestyle='dashed', color=color, label=f'{datagen.truemu if adjust else datagen.clippedtruemu:.3g}')\n",
"    ax[1].set_xlabel(f'examples (x {dt})')\n",
"    ax[1].set_ylabel('raw bounds')\n",
"    ax[1].set_xscale('log')\n",
"    ax[1].legend()\n",
"\n",
"    pstr = ','.join([f'{v:.3g}' for v in datagen.probs])\n",
"    fig.suptitle(f'expwsq = {datagen.expwsq} wmax={datagen.wmax} truemu={datagen.truemu} p={pstr} pdrop $\\\\in [{pdroplow}, {pdrophigh}]$')\n",
"\n",
"    return None\n",
"\n",
"def flass(seed):\n",
"    dg = DataGen(wmax=10, expwsq=2, truemu=1/2, rvals=(-2, 2), seed=seed)\n",
"    megasim(T=100000, wmax=10, dt=10, datagen=dg, adjust=True, seed=seed+1)\n",
"    megasim(T=100000, wmax=10, dt=10, datagen=dg, adjust=False, seed=seed+1)\n",
"    dg = DataGen(wmax=10, expwsq=2, truemu=1/2, rvals=(0, 1), seed=seed)\n",
"    megasim(T=100000, wmax=10, dt=10, datagen=dg, adjust=False, seed=seed+1)\n",
"\n",
"flass(4545)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "223a98ab",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
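
The `p_drop` / `n_drop` handling in `addobs` can be illustrated in isolation. The following is a minimal sketch, not the notebook's confidence sequence: it simulates an oblivious censoring process with the standard library only and checks two facts the notebook's bookkeeping appears to rely on. First, dividing the importance weight by `1 - p_drop` while still counting every dropped event as a zero-weight observation keeps the running average unbiased. Second, the mean number of drops before a surviving event is `p_drop / (1 - p_drop)`, the default used when `n_drop is None`. The function name `simulate_censored_ips`, the on-policy setting (importance weight 1), and all parameter values are assumptions chosen for illustration.

```python
import random


def simulate_censored_ips(*, n_events, p_drop, true_mean, seed):
    """Simulate oblivious censoring; return (estimate, mean_drops_per_kept).

    Rewards are Bernoulli(true_mean) and the importance weight is 1
    (on-policy), so the target value is simply true_mean.
    """
    rng = random.Random(seed)
    total_wr = 0.0   # running sum of reweighted rewards
    n_kept = 0       # events that survived censoring
    n_dropped = 0    # events that were censored

    for _ in range(n_events):
        r = 1.0 if rng.random() < true_mean else 0.0
        # Oblivious: the drop decision never looks at the reward.
        if rng.random() < p_drop:
            n_dropped += 1            # contributes w = 0 to the sum
        else:
            w = 1.0 / (1.0 - p_drop)  # same correction as `w /= (1 - p_drop)`
            total_wr += w * r
            n_kept += 1

    estimate = total_wr / n_events    # divide by ALL events, dropped included
    mean_drops = n_dropped / n_kept
    return estimate, mean_drops


estimate, mean_drops = simulate_censored_ips(
    n_events=200_000, p_drop=0.9, true_mean=0.3, seed=0
)
print(f"estimate = {estimate:.3f} (target 0.3)")
print(f"mean drops per kept event = {mean_drops:.2f} (p/(1-p) = 9)")
```

Dividing by the total event count, rather than by the number of surviving events, mirrors how the notebook advances `self.t` by `n_drop + 1` on each call to `addobs`.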