Last active
October 19, 2024 14:50
-
-
Save pmineiro/390d6cc820c628d04dea991f8018c054 to your computer and use it in GitHub Desktop.
--cb_dro demo for Vowpal Wabbit using the Covertype dataset. To see the lift, compare the "since last acc" column with and without --cb_dro.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"code_folding": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"class EasyAcc:\n", | |
" def __init__(self):\n", | |
" self.n = 0\n", | |
" self.sum = 0\n", | |
"\n", | |
" def __iadd__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum += other\n", | |
" return self\n", | |
" \n", | |
" def __isub__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum -= other\n", | |
" return self\n", | |
"\n", | |
" def mean(self):\n", | |
" return self.sum / max(self.n, 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"code_folding": [ | |
10, | |
30, | |
72 | |
] | |
}, | |
"outputs": [], | |
"source": [ | |
"def cb_explore_adf_covertype_demo(usedro=False):\n", | |
" from collections import Counter\n", | |
" from sklearn.datasets import fetch_covtype\n", | |
" from sklearn.decomposition import PCA\n", | |
" from vowpalwabbit import pyvw\n", | |
" from math import ceil\n", | |
" import numpy as np\n", | |
" \n", | |
" np.random.seed(31337)\n", | |
"\n", | |
" if True:\n", | |
" Object = lambda **kwargs: type(\"Object\", (), kwargs)()\n", | |
"\n", | |
" cov = fetch_covtype()\n", | |
" cov.data = PCA(whiten=True).fit_transform(cov.data)\n", | |
" cov.target -= 1\n", | |
" assert 7 == len(Counter(cov.target))\n", | |
" npretrain = ceil(0.1 * cov.data.shape[0])\n", | |
" order = np.random.permutation(cov.data.shape[0])\n", | |
" pretrain = Object(data = cov.data[order[:npretrain]], target = cov.target[order[:npretrain]])\n", | |
" offpolicylearn = Object(data = cov.data[order[npretrain:]], target = cov.target[order[npretrain:]])\n", | |
" \n", | |
" print('****** pretraining phase (online learning) ******')\n", | |
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n", | |
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n", | |
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" vw = pyvw.vw('--cb_explore_adf --cubic axx -q ax --ignore_linear x --noconstant')\n", | |
" for exno, (ex, label) in enumerate(zip(pretrain.data, pretrain.target)):\n", | |
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n", | |
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n", | |
" pred = vw.predict(exstr)\n", | |
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n", | |
" probs /= np.sum(probs)\n", | |
" action = np.random.choice(7, p=probs)\n", | |
" loggingacc += 1 if action == label else 0\n", | |
" plog += probs[action]\n", | |
" sincelastplog += probs[action]\n", | |
" \n", | |
" argmaxaction = np.argmax(probs)\n", | |
" piacc += 1 if argmaxaction == label else 0\n", | |
" sincelastpiacc += 1 if argmaxaction == label else 0\n", | |
" \n", | |
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n", | |
" for k in range(7)\n", | |
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n", | |
" ])\n", | |
" \n", | |
" vw.learn(labelexstr)\n", | |
"\n", | |
" if (exno & (exno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n", | |
" loggingacc.n,\n", | |
" loggingacc.mean(),\n", | |
" plog.mean(),\n", | |
" sincelastplog.mean(),\n", | |
" piacc.mean(),\n", | |
" sincelastpiacc.mean()\n", | |
" ),\n", | |
" flush=True)\n", | |
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]\n", | |
" \n", | |
" print('****** off-policy learning phase ******')\n", | |
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n", | |
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n", | |
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" offpolicyvw = pyvw.vw(f'--cb_adf --cubic axx -q ax --ignore_linear x --noconstant {\"--cb_dro\" if usedro else \"\"}') \n", | |
" for exno, (ex, label) in enumerate(zip(offpolicylearn.data, offpolicylearn.target)):\n", | |
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n", | |
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n", | |
" pred = vw.predict(exstr)\n", | |
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n", | |
" probs /= np.sum(probs)\n", | |
" action = np.random.choice(7, p=probs)\n", | |
" loggingacc += 1 if action == label else 0\n", | |
" plog += probs[action]\n", | |
" sincelastplog += probs[action]\n", | |
" \n", | |
" offpred = offpolicyvw.predict(exstr)\n", | |
" argmaxaction = np.argmin(offpred)\n", | |
" piacc += 1 if argmaxaction == label else 0\n", | |
" sincelastpiacc += 1 if argmaxaction == label else 0\n", | |
" \n", | |
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n", | |
" for k in range(7)\n", | |
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n", | |
" ])\n", | |
" \n", | |
" offpolicyvw.learn(labelexstr)\n", | |
"\n", | |
" if (exno & (exno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n", | |
" loggingacc.n,\n", | |
" loggingacc.mean(),\n", | |
" plog.mean(),\n", | |
" sincelastplog.mean(),\n", | |
" piacc.mean(),\n", | |
" sincelastpiacc.mean()\n", | |
" ),\n", | |
" flush=True)\n", | |
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"code_folding": [ | |
10, | |
31 | |
] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"****** pretraining phase (online learning) ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n", | |
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n", | |
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n", | |
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n", | |
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n", | |
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n", | |
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n", | |
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n", | |
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n", | |
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n", | |
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n", | |
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n", | |
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n", | |
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n", | |
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n", | |
"****** off-policy learning phase ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n", | |
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n", | |
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n", | |
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n", | |
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n", | |
"17 \t0.76471 \t0.95714 \t0.95714 \t0.17647 \t0.25000 \n", | |
"33 \t0.72727 \t0.89957 \t0.83839 \t0.09091 \t0.00000 \n", | |
"65 \t0.70769 \t0.91330 \t0.92746 \t0.29231 \t0.50000 \n", | |
"129 \t0.68217 \t0.91296 \t0.91261 \t0.32558 \t0.35938 \n", | |
"257 \t0.63424 \t0.88691 \t0.86066 \t0.42023 \t0.51562 \n", | |
"513 \t0.64522 \t0.89233 \t0.89777 \t0.50487 \t0.58984 \n", | |
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.56195 \t0.61914 \n", | |
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.61689 \t0.67188 \n", | |
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63730 \t0.65771 \n", | |
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66337 \t0.68945 \n", | |
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68459 \t0.70581 \n", | |
"32769\t0.69676 \t0.91618 \t0.91713 \t0.69398 \t0.70337 \n", | |
"65537\t0.69745 \t0.91590 \t0.91563 \t0.69762 \t0.70126 \n", | |
"131073\t0.69714 \t0.91613 \t0.91637 \t0.70680 \t0.71599 \n", | |
"262145\t0.69574 \t0.91587 \t0.91561 \t0.71269 \t0.71858 \n" | |
] | |
} | |
], | |
"source": [ | |
"cb_explore_adf_covertype_demo()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"****** pretraining phase (online learning) ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n", | |
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n", | |
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n", | |
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n", | |
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n", | |
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n", | |
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n", | |
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n", | |
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n", | |
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n", | |
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n", | |
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n", | |
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n", | |
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n", | |
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n", | |
"****** off-policy learning phase ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n", | |
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n", | |
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n", | |
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n", | |
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n", | |
"17 \t0.76471 \t0.95714 \t0.95714 \t0.11765 \t0.12500 \n", | |
"33 \t0.72727 \t0.89957 \t0.83839 \t0.06061 \t0.00000 \n", | |
"65 \t0.70769 \t0.91330 \t0.92746 \t0.12308 \t0.18750 \n", | |
"129 \t0.68217 \t0.91296 \t0.91261 \t0.17054 \t0.21875 \n", | |
"257 \t0.63424 \t0.88691 \t0.86066 \t0.30350 \t0.43750 \n", | |
"513 \t0.64522 \t0.89233 \t0.89777 \t0.45419 \t0.60547 \n", | |
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.53854 \t0.62305 \n", | |
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.59834 \t0.65820 \n", | |
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63241 \t0.66650 \n", | |
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66301 \t0.69360 \n", | |
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68727 \t0.71155 \n", | |
"32769\t0.69676 \t0.91618 \t0.91713 \t0.70094 \t0.71460 \n", | |
"65537\t0.69745 \t0.91590 \t0.91563 \t0.70963 \t0.71832 \n", | |
"131073\t0.69714 \t0.91613 \t0.91637 \t0.71910 \t0.72856 \n", | |
"262145\t0.69574 \t0.91587 \t0.91561 \t0.72653 \t0.73396 \n" | |
] | |
} | |
], | |
"source": [ | |
"cb_explore_adf_covertype_demo(usedro=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
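In this gist I show that passing the --cb_dro flag to the off-policy learner (--cb_adf ... --cb_dro) improves the trained policy from 71.8% to 73.3% accuracy (final "since last acc" value, without vs. with --cb_dro).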