Skip to content

Instantly share code, notes, and snippets.

@pmineiro
Last active October 19, 2024 14:50
Show Gist options
  • Save pmineiro/390d6cc820c628d04dea991f8018c054 to your computer and use it in GitHub Desktop.
Save pmineiro/390d6cc820c628d04dea991f8018c054 to your computer and use it in GitHub Desktop.
--cb_dro demo for vowpal wabbit using covertype. To see the lift, note the "since last acc" column with and without --cb_dro.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"class EasyAcc:\n",
" def __init__(self):\n",
" self.n = 0\n",
" self.sum = 0\n",
"\n",
" def __iadd__(self, other):\n",
" self.n += 1\n",
" self.sum += other\n",
" return self\n",
" \n",
" def __isub__(self, other):\n",
" self.n += 1\n",
" self.sum -= other\n",
" return self\n",
"\n",
" def mean(self):\n",
" return self.sum / max(self.n, 1)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"code_folding": [
10,
30,
72
]
},
"outputs": [],
"source": [
"def cb_explore_adf_covertype_demo(usedro=False):\n",
" from collections import Counter\n",
" from sklearn.datasets import fetch_covtype\n",
" from sklearn.decomposition import PCA\n",
" from vowpalwabbit import pyvw\n",
" from math import ceil\n",
" import numpy as np\n",
" \n",
" np.random.seed(31337)\n",
"\n",
" if True:\n",
" Object = lambda **kwargs: type(\"Object\", (), kwargs)()\n",
"\n",
" cov = fetch_covtype()\n",
" cov.data = PCA(whiten=True).fit_transform(cov.data)\n",
" cov.target -= 1\n",
" assert 7 == len(Counter(cov.target))\n",
" npretrain = ceil(0.1 * cov.data.shape[0])\n",
" order = np.random.permutation(cov.data.shape[0])\n",
" pretrain = Object(data = cov.data[order[:npretrain]], target = cov.target[order[:npretrain]])\n",
" offpolicylearn = Object(data = cov.data[order[npretrain:]], target = cov.target[order[npretrain:]])\n",
" \n",
" print('****** pretraining phase (online learning) ******')\n",
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n",
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n",
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n",
" ),\n",
" flush=True)\n",
" \n",
" vw = pyvw.vw('--cb_explore_adf --cubic axx -q ax --ignore_linear x --noconstant')\n",
" for exno, (ex, label) in enumerate(zip(pretrain.data, pretrain.target)):\n",
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n",
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n",
" pred = vw.predict(exstr)\n",
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n",
" probs /= np.sum(probs)\n",
" action = np.random.choice(7, p=probs)\n",
" loggingacc += 1 if action == label else 0\n",
" plog += probs[action]\n",
" sincelastplog += probs[action]\n",
" \n",
" argmaxaction = np.argmax(probs)\n",
" piacc += 1 if argmaxaction == label else 0\n",
" sincelastpiacc += 1 if argmaxaction == label else 0\n",
" \n",
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n",
" for k in range(7)\n",
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n",
" ])\n",
" \n",
" vw.learn(labelexstr)\n",
"\n",
" if (exno & (exno - 1) == 0):\n",
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n",
" loggingacc.n,\n",
" loggingacc.mean(),\n",
" plog.mean(),\n",
" sincelastplog.mean(),\n",
" piacc.mean(),\n",
" sincelastpiacc.mean()\n",
" ),\n",
" flush=True)\n",
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]\n",
" \n",
" print('****** off-policy learning phase ******')\n",
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n",
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n",
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n",
" ),\n",
" flush=True)\n",
" \n",
" offpolicyvw = pyvw.vw(f'--cb_adf --cubic axx -q ax --ignore_linear x --noconstant {\"--cb_dro\" if usedro else \"\"}') \n",
" for exno, (ex, label) in enumerate(zip(offpolicylearn.data, offpolicylearn.target)):\n",
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n",
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n",
" pred = vw.predict(exstr)\n",
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n",
" probs /= np.sum(probs)\n",
" action = np.random.choice(7, p=probs)\n",
" loggingacc += 1 if action == label else 0\n",
" plog += probs[action]\n",
" sincelastplog += probs[action]\n",
" \n",
" offpred = offpolicyvw.predict(exstr)\n",
" argmaxaction = np.argmin(offpred)\n",
" piacc += 1 if argmaxaction == label else 0\n",
" sincelastpiacc += 1 if argmaxaction == label else 0\n",
" \n",
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n",
" for k in range(7)\n",
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n",
" ])\n",
" \n",
" offpolicyvw.learn(labelexstr)\n",
"\n",
" if (exno & (exno - 1) == 0):\n",
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n",
" loggingacc.n,\n",
" loggingacc.mean(),\n",
" plog.mean(),\n",
" sincelastplog.mean(),\n",
" piacc.mean(),\n",
" sincelastpiacc.mean()\n",
" ),\n",
" flush=True)\n",
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"code_folding": [
10,
31
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"****** pretraining phase (online learning) ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n",
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n",
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n",
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n",
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n",
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n",
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n",
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n",
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n",
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n",
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n",
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n",
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n",
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n",
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n",
"****** off-policy learning phase ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n",
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n",
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n",
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n",
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n",
"17 \t0.76471 \t0.95714 \t0.95714 \t0.17647 \t0.25000 \n",
"33 \t0.72727 \t0.89957 \t0.83839 \t0.09091 \t0.00000 \n",
"65 \t0.70769 \t0.91330 \t0.92746 \t0.29231 \t0.50000 \n",
"129 \t0.68217 \t0.91296 \t0.91261 \t0.32558 \t0.35938 \n",
"257 \t0.63424 \t0.88691 \t0.86066 \t0.42023 \t0.51562 \n",
"513 \t0.64522 \t0.89233 \t0.89777 \t0.50487 \t0.58984 \n",
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.56195 \t0.61914 \n",
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.61689 \t0.67188 \n",
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63730 \t0.65771 \n",
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66337 \t0.68945 \n",
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68459 \t0.70581 \n",
"32769\t0.69676 \t0.91618 \t0.91713 \t0.69398 \t0.70337 \n",
"65537\t0.69745 \t0.91590 \t0.91563 \t0.69762 \t0.70126 \n",
"131073\t0.69714 \t0.91613 \t0.91637 \t0.70680 \t0.71599 \n",
"262145\t0.69574 \t0.91587 \t0.91561 \t0.71269 \t0.71858 \n"
]
}
],
"source": [
"cb_explore_adf_covertype_demo()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"****** pretraining phase (online learning) ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n",
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n",
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n",
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n",
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n",
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n",
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n",
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n",
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n",
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n",
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n",
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n",
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n",
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n",
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n",
"****** off-policy learning phase ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n",
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n",
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n",
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n",
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n",
"17 \t0.76471 \t0.95714 \t0.95714 \t0.11765 \t0.12500 \n",
"33 \t0.72727 \t0.89957 \t0.83839 \t0.06061 \t0.00000 \n",
"65 \t0.70769 \t0.91330 \t0.92746 \t0.12308 \t0.18750 \n",
"129 \t0.68217 \t0.91296 \t0.91261 \t0.17054 \t0.21875 \n",
"257 \t0.63424 \t0.88691 \t0.86066 \t0.30350 \t0.43750 \n",
"513 \t0.64522 \t0.89233 \t0.89777 \t0.45419 \t0.60547 \n",
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.53854 \t0.62305 \n",
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.59834 \t0.65820 \n",
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63241 \t0.66650 \n",
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66301 \t0.69360 \n",
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68727 \t0.71155 \n",
"32769\t0.69676 \t0.91618 \t0.91713 \t0.70094 \t0.71460 \n",
"65537\t0.69745 \t0.91590 \t0.91563 \t0.70963 \t0.71832 \n",
"131073\t0.69714 \t0.91613 \t0.91637 \t0.71910 \t0.72856 \n",
"262145\t0.69574 \t0.91587 \t0.91561 \t0.72653 \t0.73396 \n"
]
}
],
"source": [
"cb_explore_adf_covertype_demo(usedro=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@pmineiro
Copy link
Author

pmineiro commented Dec 5, 2020

In this gist I:

  • pre-train a logging policy using 10% of covertype, and then fix the logging policy thereafter
  • off-policy train another policy using data from the logging policy, either with or without the --cb_dro flag
  • --cb_dro improves the trained policy from 71.8% to 73.3% accuracy.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment