Skip to content

Instantly share code, notes, and snippets.

@ghimiremukesh
Last active April 11, 2024 17:44
Show Gist options
  • Save ghimiremukesh/c107dad1d2a64f1055dcf830f53be113 to your computer and use it in GitHub Desktop.
Save ghimiremukesh/c107dad1d2a64f1055dcf830f53be113 to your computer and use it in GitHub Desktop.
test_beer_quiche_primal_dual.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ghimiremukesh/c107dad1d2a64f1055dcf830f53be113/test_beer_quiche_primal_dual.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "34ff4a80-5daa-4b89-9e15-f84ba234611f",
"metadata": {
"id": "34ff4a80-5daa-4b89-9e15-f84ba234611f"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from typing import Dict"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8cd766cf-9ae7-486b-9be9-c5eb85bb93f5",
"metadata": {
"id": "8cd766cf-9ae7-486b-9be9-c5eb85bb93f5"
},
"outputs": [],
"source": [
"# Action labels for the two players; the position in each list is the action index.\n",
"p1_actions = list('BQ')  # player 1: 'B' or 'Q' (presumably Beer / Quiche)\n",
"p2_actions = list('bd')  # player 2: 'b' or 'd' (bully / defer, per best_responses)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "94b77a11-b116-431c-8943-a623b4a8441b",
"metadata": {
"id": "94b77a11-b116-431c-8943-a623b4a8441b"
},
"outputs": [],
"source": [
"class BeerQuiche():\n",
"    \"\"\"Terminal payoffs of the toy game, keyed by history and player-1 type.\"\"\"\n",
"\n",
"    @staticmethod\n",
"    def get_payoff(history, p1_type):\n",
"        \"\"\"Look up the payoff for a finished play.\n",
"\n",
"        Args:\n",
"            history: P1 action followed by P2 action, e.g. 'Bb' or 'Qd'.\n",
"            p1_type: 'T' selects the tough table; any other value selects\n",
"                the second table.\n",
"\n",
"        Returns:\n",
"            The integer payoff. A history missing from the selected table\n",
"            gets that table's fallback value (the one that applies to 'Qd').\n",
"        \"\"\"\n",
"        if p1_type == 'T':  # tough\n",
"            table, fallback = {'Bb': 2, 'Bd': 1, 'Qb': 1}, 0\n",
"        else:\n",
"            table, fallback = {'Bb': -2, 'Bd': 0, 'Qb': -1}, 2\n",
"\n",
"        return table.get(history, fallback)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "5f9a6b82-82fa-41f4-907b-47ef5c512e4d",
"metadata": {
"id": "5f9a6b82-82fa-41f4-907b-47ef5c512e4d"
},
"outputs": [],
"source": [
"## implement naive cfr-like trainer for beer-quiche\n",
"\n",
"# best response solver ---> use expected value calculation from the appendix in the paper\n",
"def best_responses(B, Q, b, q):\n",
"    \"\"\"Return player 2's best-response action index at each information set.\n",
"\n",
"    Args:\n",
"        B, Q: probabilities of P1 playing 'B' / 'Q' for the first type\n",
"            (weight 1 in the formulas below).\n",
"        b, q: the same probabilities for the second type (weight 2).\n",
"\n",
"    Returns:\n",
"        (a_Q, a_B): argmax index over [bully, defer] after observing 'Q'\n",
"        and after observing 'B' respectively; ties resolve to index 0.\n",
"    \"\"\"\n",
"    def expected_values(bully_num, defer_num, reach):\n",
"        # Both values share the denominator `reach` (unnormalised mass of the\n",
"        # information set). If it is 0 the set is never reached: dividing would\n",
"        # give nan plus a RuntimeWarning for numpy floats, or ZeroDivisionError\n",
"        # for Python floats. Divide by 1 instead; for non-negative inputs both\n",
"        # numerators are 0 there, so argmax falls back to index 0, which is\n",
"        # also what argmax over two nan values returned.\n",
"        scale = reach if reach != 0 else 1.0\n",
"        return np.array([bully_num, defer_num]) / scale\n",
"\n",
"    IQ = expected_values(-(Q - 2 * q), -(4 * q), Q + 2 * q)\n",
"    IB = expected_values(4 * b - 2 * B, -B, B + 2 * b)\n",
"\n",
"    return np.argmax(IQ), np.argmax(IB)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d98d4333-5dfa-4d7d-99e5-9b74f1ad68f9",
"metadata": {
"id": "d98d4333-5dfa-4d7d-99e5-9b74f1ad68f9"
},
"outputs": [],
"source": [
"class information_set():\n",
"    \"\"\"Regret-matching bookkeeping for one two-action decision point.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        self.num_actions = 2\n",
"        # Running totals, one entry per action.\n",
"        self.cumulative_regrets = np.zeros(shape=2)\n",
"        self.strategy_sum = np.zeros(shape=2)\n",
"\n",
"    def normalize(self, strategy):\n",
"        \"\"\"Scale `strategy` in place to sum to 1 and return it.\n",
"\n",
"        When the total is not positive, a new uniform vector is returned\n",
"        instead and `strategy` is left untouched.\n",
"        \"\"\"\n",
"        total = sum(strategy)\n",
"        if not total > 0:\n",
"            return np.ones(self.num_actions) / self.num_actions\n",
"\n",
"        strategy /= total  # in place: the array passed in is modified\n",
"        return strategy\n",
"\n",
"    def get_strategy(self, reach_probability=1):\n",
"        \"\"\"Return the regret-matching strategy and add it to `strategy_sum`.\n",
"\n",
"        The contribution to `strategy_sum` is weighted by `reach_probability`.\n",
"        \"\"\"\n",
"        positive_regrets = np.maximum(0, self.cumulative_regrets)\n",
"        current = self.normalize(positive_regrets)\n",
"        self.strategy_sum += reach_probability * current\n",
"\n",
"        return current\n",
"\n",
"    def get_average_strategy(self):\n",
"        \"\"\"Return the normalized `strategy_sum` without modifying it.\"\"\"\n",
"        accumulated = self.strategy_sum.copy()\n",
"        return self.normalize(accumulated)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "48976ab7-f9a7-4259-9eb5-779e56ca06aa",
"metadata": {
"id": "48976ab7-f9a7-4259-9eb5-779e56ca06aa"
},
"outputs": [],
"source": [
"class cfr_trainer_primal():\n",
"    \"\"\"CFR-style trainer for player 1 against a best-responding player 2.\n",
"\n",
"    Keeps one information_set per player-1 type in `infoset_map`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, types):\n",
"        \"\"\"Create an information set for every element of `types` (e.g. 'TW').\"\"\"\n",
"        self.infoset_map: Dict[str, information_set] = {\n",
"            t: information_set() for t in types\n",
"        }\n",
"\n",
"    def get_all_strategies(self):\n",
"        \"\"\"Return (B, Q, b, q): current action probabilities of types 'T' and 'W'.\n",
"\n",
"        NOTE(review): get_strategy() also adds its result to the node's\n",
"        strategy_sum, and cfr() calls it once more for the sampled type, so\n",
"        that type is counted twice in the average on that iteration. Left\n",
"        unchanged to keep results identical -- confirm this is intended.\n",
"        \"\"\"\n",
"        B, Q = self.infoset_map['T'].get_strategy()\n",
"        b, q = self.infoset_map['W'].get_strategy()\n",
"\n",
"        return B, Q, b, q\n",
"\n",
"    def get_information_set(self, type_and_history):\n",
"        \"\"\"Return the stored information set for `type_and_history`.\n",
"\n",
"        Raises:\n",
"            KeyError: if no information set was created for that key.\n",
"        \"\"\"\n",
"        if type_and_history not in self.infoset_map:\n",
"            # Raising a plain string is itself a TypeError in Python 3.\n",
"            raise KeyError(\n",
"                'info sets for P1 must be initialized in the beginning for the toy game; '\n",
"                f'unknown key {type_and_history!r}'\n",
"            )\n",
"\n",
"        return self.infoset_map[type_and_history]\n",
"\n",
"    def cfr(self, p1_type, history):\n",
"        \"\"\"Run one regret update for `p1_type` and return its expected payoff.\n",
"\n",
"        Args:\n",
"            p1_type: player-1 type label, 'T' or 'W'.\n",
"            history: suffix appended to `p1_type` to form the information-set\n",
"                key (train() passes '').\n",
"\n",
"        Returns:\n",
"            Expected payoff of the current strategy against P2's best response.\n",
"        \"\"\"\n",
"        # compute best-response\n",
"        B, Q, b, q = self.get_all_strategies()  # get action probs for p1\n",
"        a_Q, a_B = best_responses(B, Q, b, q)  # best response at I_Q and I_B\n",
"\n",
"        info_set = self.get_information_set(p1_type + history)\n",
"        strategy = info_set.get_strategy()\n",
"\n",
"        payoff_values = np.zeros(len(p1_actions))\n",
"        for ix, action in enumerate(p1_actions):\n",
"            p2_a_ix = a_Q if action == 'Q' else a_B\n",
"            # Separate name so the `history` argument is not overwritten.\n",
"            terminal_history = action + p2_actions[p2_a_ix]\n",
"            payoff_values[ix] = BeerQuiche.get_payoff(terminal_history, p1_type)\n",
"\n",
"        expected_payoff = payoff_values.dot(strategy)\n",
"        # Regret of each action = its payoff minus the strategy's expected payoff.\n",
"        info_set.cumulative_regrets += payoff_values - expected_payoff\n",
"\n",
"        return expected_payoff\n",
"\n",
"    def train(self, num_itrs, rng=None):\n",
"        \"\"\"Run `num_itrs` sampled iterations and return the summed expected payoff.\n",
"\n",
"        Args:\n",
"            num_itrs: number of iterations; each samples one player-1 type.\n",
"            rng: optional object with a numpy-style `choice(a, p=...)` method,\n",
"                e.g. np.random.default_rng(seed). Defaults to the global\n",
"                `np.random` module.\n",
"        \"\"\"\n",
"        if rng is None:\n",
"            rng = np.random\n",
"\n",
"        p = 1/3  # probability of sampling type 'T'; 'W' gets 1 - p\n",
"        types = [1, 2]\n",
"        util = 0\n",
"        for _ in range(num_itrs):\n",
"            p1_t = rng.choice(types, p=[p, 1 - p])\n",
"            p1_type = 'T' if p1_t == 1 else 'W'\n",
"            util += self.cfr(p1_type, '')\n",
"        return util"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "f0da4bcf-1d30-4e80-af3b-c0bb1b897bf7",
"metadata": {
"id": "f0da4bcf-1d30-4e80-af3b-c0bb1b897bf7"
},
"outputs": [],
"source": [
"# train() samples types through the global numpy RNG; seed it so that\n",
"# Restart & Run All reproduces the same result.\n",
"np.random.seed(0)\n",
"\n",
"trainer = cfr_trainer_primal('TW')  # one information set per type: 'T' and 'W'\n",
"num_itrs = 1000\n",
"util = trainer.train(num_itrs)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "e0df1770-00de-4b0f-9c85-3ef2bb1c5e82",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e0df1770-00de-4b0f-9c85-3ef2bb1c5e82",
"outputId": "a7cb6f5a-939a-47cd-8096-bcac2015ca67"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"T: [1. 0.]\n",
"W: [0.25 0.75]\n"
]
}
],
"source": [
"# Print each type's average strategy, rounded to 2 decimals, in key order.\n",
"for name in sorted(trainer.infoset_map):\n",
"    info_set = trainer.infoset_map[name]\n",
"    print(f\"{name}: {np.round(info_set.get_average_strategy(), 2)}\")"
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "JKh1SMUSST2D"
},
"id": "JKh1SMUSST2D",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment