Last active
April 11, 2024 17:44
-
-
Save ghimiremukesh/c107dad1d2a64f1055dcf830f53be113 to your computer and use it in GitHub Desktop.
test_beer_quiche_primal_dual.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ghimiremukesh/c107dad1d2a64f1055dcf830f53be113/test_beer_quiche_primal_dual.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "34ff4a80-5daa-4b89-9e15-f84ba234611f", | |
"metadata": { | |
"id": "34ff4a80-5daa-4b89-9e15-f84ba234611f" | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from typing import Dict" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "8cd766cf-9ae7-486b-9be9-c5eb85bb93f5", | |
"metadata": { | |
"id": "8cd766cf-9ae7-486b-9be9-c5eb85bb93f5" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Player 1's actions in index order; each one is the first letter of a history string\n", | |
"# such as 'Bb' (presumably B = beer, Q = quiche, after the game's name).\n", | |
"p1_actions = ['B', 'Q']\n", | |
"# Player 2's actions in index order; each one is the second letter of a history string.\n", | |
"# The order matches best_responses(), whose argmax runs over [e_bully, e_defer].\n", | |
"p2_actions = ['b', 'd']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "94b77a11-b116-431c-8943-a623b4a8441b", | |
"metadata": { | |
"id": "94b77a11-b116-431c-8943-a623b4a8441b" | |
}, | |
"outputs": [], | |
"source": [ | |
"class BeerQuiche():\n", | |
"    \"\"\"Payoff table for the Beer-Quiche toy game.\"\"\"\n", | |
"\n", | |
"    @staticmethod\n", | |
"    def get_payoff(history, p1_type):\n", | |
"        \"\"\"Look up the payoff of a finished play.\n", | |
"\n", | |
"        Args:\n", | |
"            history: string made of P1's action followed by P2's action,\n", | |
"                e.g. 'Bb' or 'Qd'.\n", | |
"            p1_type: 'T' selects the tough table; any other value (the\n", | |
"                trainer passes 'W') selects the second table.\n", | |
"\n", | |
"        Returns:\n", | |
"            The integer payoff. A history that is not 'Bb', 'Bd' or 'Qb'\n", | |
"            gets the table's fallback value (0 for 'T', 2 otherwise).\n", | |
"        \"\"\"\n", | |
"        if p1_type == 'T':  # tough\n", | |
"            listed, fallback = {'Bb': 2, 'Bd': 1, 'Qb': 1}, 0\n", | |
"        else:\n", | |
"            listed, fallback = {'Bb': -2, 'Bd': 0, 'Qb': -1}, 2\n", | |
"\n", | |
"        return listed.get(history, fallback)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "5f9a6b82-82fa-41f4-907b-47ef5c512e4d", | |
"metadata": { | |
"id": "5f9a6b82-82fa-41f4-907b-47ef5c512e4d" | |
}, | |
"outputs": [], | |
"source": [ | |
"## implement naive cfr-like trainer for beer-quiche\n", | |
"\n", | |
"# best response solver ---> use expected value calculation from the appendix in the paper\n", | |
"def best_responses(B, Q, b, q):\n", | |
"    \"\"\"Return player 2's best-response action index at its two information sets.\n", | |
"\n", | |
"    Args:\n", | |
"        B, Q: probabilities with which type 'T' plays 'B' and 'Q'.\n", | |
"        b, q: probabilities with which type 'W' plays 'B' and 'Q'.\n", | |
"            All four are expected to be non-negative.\n", | |
"\n", | |
"    Returns:\n", | |
"        Tuple (a_Q, a_B): the argmax index over [e_bully, e_defer] after\n", | |
"        observing 'Q' and after observing 'B'. Ties resolve to index 0.\n", | |
"\n", | |
"    The two expected values at an information set share one denominator\n", | |
"    (Q + 2*q after 'Q', B + 2*b after 'B'). A positive denominator never\n", | |
"    changes which value is larger, so only the numerators are compared.\n", | |
"    This avoids the 0/0 that the quotient form hits when an information set\n", | |
"    is reached with probability zero (NaN plus a RuntimeWarning for numpy\n", | |
"    scalars, ZeroDivisionError for plain Python numbers). In that case both\n", | |
"    numerators are zero and index 0 is returned, which is also what\n", | |
"    np.argmax gives for an all-NaN pair.\n", | |
"    \"\"\"\n", | |
"    # Numerators of [e_bully, e_defer] after observing Q.\n", | |
"    IQ = np.array([2 * q - Q, -4 * q])\n", | |
"\n", | |
"    # Numerators of [e_bully, e_defer] after observing B.\n", | |
"    IB = np.array([4 * b - 2 * B, -B])\n", | |
"\n", | |
"    return np.argmax(IQ), np.argmax(IB)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "d98d4333-5dfa-4d7d-99e5-9b74f1ad68f9", | |
"metadata": { | |
"id": "d98d4333-5dfa-4d7d-99e5-9b74f1ad68f9" | |
}, | |
"outputs": [], | |
"source": [ | |
"class information_set():\n", | |
"    \"\"\"Regret-matching state for one information set.\n", | |
"\n", | |
"    Holds two numpy arrays with one entry per action: cumulative_regrets\n", | |
"    (summed regrets) and strategy_sum (reach-weighted sum of every strategy\n", | |
"    returned by get_strategy).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, num_actions=2):\n", | |
"        \"\"\"Create zeroed accumulators for num_actions actions (default 2).\"\"\"\n", | |
"        self.num_actions = num_actions\n", | |
"        self.cumulative_regrets = np.zeros(shape=num_actions)\n", | |
"        self.strategy_sum = np.zeros(shape=num_actions)\n", | |
"\n", | |
"    def normalize(self, strategy):\n", | |
"        \"\"\"Normalize strategy. If no positive regrets, strategy is unif. random\n", | |
"\n", | |
"        Returns a new float array; the argument is never modified.\n", | |
"        \"\"\"\n", | |
"        weights = np.asarray(strategy, dtype=float)\n", | |
"        total = weights.sum()\n", | |
"        if total > 0:\n", | |
"            return weights / total\n", | |
"\n", | |
"        return np.ones(self.num_actions) / self.num_actions\n", | |
"\n", | |
"    def get_strategy(self, reach_probability=1):\n", | |
"        \"\"\"Return regret matching strategy\n", | |
"\n", | |
"        Side effect: adds reach_probability * strategy to strategy_sum, so\n", | |
"        every call counts towards the average strategy.\n", | |
"        \"\"\"\n", | |
"        # Only positive regrets carry weight.\n", | |
"        strategy = self.normalize(np.maximum(0, self.cumulative_regrets))\n", | |
"\n", | |
"        self.strategy_sum += reach_probability * strategy\n", | |
"\n", | |
"        return strategy\n", | |
"\n", | |
"    def get_average_strategy(self):\n", | |
"        \"\"\"Return strategy_sum normalized to a distribution (uniform if it is all zero).\"\"\"\n", | |
"        return self.normalize(self.strategy_sum)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "48976ab7-f9a7-4259-9eb5-779e56ca06aa", | |
"metadata": { | |
"id": "48976ab7-f9a7-4259-9eb5-779e56ca06aa" | |
}, | |
"outputs": [], | |
"source": [ | |
"class cfr_trainer_primal():\n", | |
"    \"\"\"Naive CFR-like trainer for player 1 of the Beer-Quiche toy game.\n", | |
"\n", | |
"    Player 1 keeps one regret-matching information set per type. Player 2 is\n", | |
"    not learned: in every iteration it plays best_responses() to player 1's\n", | |
"    current strategies.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, types):\n", | |
"        \"\"\"Create one information set per player-1 type.\n", | |
"\n", | |
"        Args:\n", | |
"            types: iterable of type labels. The other methods look up 'T'\n", | |
"                and 'W', so pass e.g. 'TW' or ['T', 'W'].\n", | |
"        \"\"\"\n", | |
"        self.infoset_map: Dict[str, 'information_set'] = {}\n", | |
"        for t in types:\n", | |
"            self.infoset_map[t] = information_set()\n", | |
"\n", | |
"    def get_all_strategies(self):\n", | |
"        \"\"\"Return (B, Q, b, q), the current action probabilities of both types.\n", | |
"\n", | |
"        B, Q belong to type 'T' and b, q to type 'W'. The values come from\n", | |
"        get_strategy(), so each call also adds them to both strategy_sum\n", | |
"        accumulators.\n", | |
"        \"\"\"\n", | |
"        a_T = self.infoset_map['T'].get_strategy()\n", | |
"        a_W = self.infoset_map['W'].get_strategy()\n", | |
"\n", | |
"        B, Q = a_T[0], a_T[1]\n", | |
"        b, q = a_W[0], a_W[1]\n", | |
"\n", | |
"        return B, Q, b, q\n", | |
"\n", | |
"    def get_information_set(self, type_and_history):\n", | |
"        \"\"\"Return the information set registered under type_and_history.\n", | |
"\n", | |
"        Raises:\n", | |
"            KeyError: if the key was not created in __init__. Information\n", | |
"                sets are never added lazily in this toy game.\n", | |
"        \"\"\"\n", | |
"        if type_and_history not in self.infoset_map:\n", | |
"            # A bare string cannot be raised in Python 3 (it fails with a\n", | |
"            # TypeError that hides the message), so use a real exception.\n", | |
"            raise KeyError(\n", | |
"                \"info sets for P1 must be initialized in the beginning for the toy game\"\n", | |
"            )\n", | |
"\n", | |
"        return self.infoset_map[type_and_history]\n", | |
"\n", | |
"    def cfr(self, p1_type, history):\n", | |
"        \"\"\"Run one regret update for p1_type against player 2's best response.\n", | |
"\n", | |
"        Args:\n", | |
"            p1_type: 'T' or 'W'.\n", | |
"            history: history prefix appended to p1_type to form the\n", | |
"                information-set key (train() passes '').\n", | |
"\n", | |
"        Returns:\n", | |
"            The expected payoff of the current strategy of p1_type.\n", | |
"        \"\"\"\n", | |
"        # compute best-response\n", | |
"        B, Q, b, q = self.get_all_strategies()  # get action probs for p1\n", | |
"        a_Q, a_B = best_responses(B, Q, b, q)  # get best response for info set I_Q and I_B\n", | |
"\n", | |
"        info_set = self.get_information_set(p1_type + history)\n", | |
"        # NOTE(review): get_all_strategies() above already called get_strategy()\n", | |
"        # on this information set, so its strategy_sum grows twice in this\n", | |
"        # iteration. Confirm that this weighting of the average is intended.\n", | |
"        strategy = info_set.get_strategy()\n", | |
"\n", | |
"        payoff_values = np.zeros(2)\n", | |
"        for ix, action in enumerate(p1_actions):\n", | |
"            p2_a_ix = a_Q if action == 'Q' else a_B\n", | |
"            # Terminal history: P1's action followed by P2's best response to it.\n", | |
"            outcome = action + p2_actions[p2_a_ix]\n", | |
"\n", | |
"            payoff_values[ix] = BeerQuiche.get_payoff(outcome, p1_type)\n", | |
"\n", | |
"        expected_payoff = payoff_values.dot(strategy)\n", | |
"        # Regret of each action = its payoff minus what the current mix earns.\n", | |
"        info_set.cumulative_regrets += payoff_values - expected_payoff\n", | |
"\n", | |
"        return expected_payoff\n", | |
"\n", | |
"    def train(self, num_itrs):\n", | |
"        \"\"\"Run num_itrs iterations and return the summed expected payoffs.\n", | |
"\n", | |
"        Each iteration samples player 1's type ('T' with probability 1/3,\n", | |
"        'W' otherwise) from numpy's global RNG and calls cfr() on it.\n", | |
"        \"\"\"\n", | |
"        util = 0\n", | |
"        p = 1/3\n", | |
"        types = [1, 2]\n", | |
"        for _ in range(num_itrs):\n", | |
"            p1_t = np.random.choice(types, p=[p, 1-p])\n", | |
"            p1_type = 'T' if p1_t == 1 else 'W'\n", | |
"            history = ''\n", | |
"            util += self.cfr(p1_type, history)\n", | |
"        return util" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "f0da4bcf-1d30-4e80-af3b-c0bb1b897bf7", | |
"metadata": { | |
"id": "f0da4bcf-1d30-4e80-af3b-c0bb1b897bf7" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Seed numpy's global RNG so the sampled type sequence, and with it the\n", | |
"# strategies printed below, are the same on every Restart & Run All.\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"# 'TW' is iterated character by character: one information set per type label.\n", | |
"trainer = cfr_trainer_primal('TW')\n", | |
"num_itrs = 1000\n", | |
"util = trainer.train(num_itrs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "e0df1770-00de-4b0f-9c85-3ef2bb1c5e82", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "e0df1770-00de-4b0f-9c85-3ef2bb1c5e82", | |
"outputId": "a7cb6f5a-939a-47cd-8096-bcac2015ca67" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"T: [1. 0.]\n", | |
"W: [0.25 0.75]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Print each type's average strategy, in alphabetical order of the type label.\n", | |
"for name in sorted(trainer.infoset_map):\n", | |
"    info_set = trainer.infoset_map[name]\n", | |
"    print(f\"{name}: {np.round(info_set.get_average_strategy(), 2)}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "JKh1SMUSST2D" | |
}, | |
"id": "JKh1SMUSST2D", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.15" | |
}, | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment