Last active
April 10, 2024 21:38
-
-
Save ghimiremukesh/ff5c44fdb53fb4892fb55703a7c5d891 to your computer and use it in GitHub Desktop.
beerquiche_cfr_primal_dual.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ghimiremukesh/ff5c44fdb53fb4892fb55703a7c5d891/beerquiche_cfr.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c98ba361", | |
"metadata": { | |
"id": "c98ba361" | |
}, | |
"source": [ | |
"### CFR -- Beer Quiche Game" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "43d052ee", | |
"metadata": { | |
"id": "43d052ee" | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from typing import List, Dict\n", | |
"import random\n", | |
"import sys" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def best_response(info_set, p1_T, p1_W):\n", | |
" \"\"\"Compute best response action for P2 (from P1's perspective) given P1's strategy\"\"\"\n", | |
" if info_set == 'Q':\n", | |
" Q = p1_T\n", | |
" q = p1_W\n", | |
"\n", | |
" e_bully = - (Q - 2 * q)/(Q+2 * q)\n", | |
" e_defer = -(4 * q)/(Q + + 2 * q)\n", | |
" else:\n", | |
" B = p1_T\n", | |
" b = p1_W\n", | |
"\n", | |
" e_bully = (4*b - 2 * B)/(B + 2 * b)\n", | |
" e_defer = -B/(B + 2 * b)\n", | |
"\n", | |
" evs = np.array([e_b, e_d])\n", | |
"\n", | |
" return np.argmax(evs)" | |
], | |
"metadata": { | |
"id": "W137Brd962OW" | |
}, | |
"id": "W137Brd962OW", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "a7699947", | |
"metadata": { | |
"id": "a7699947" | |
}, | |
"outputs": [], | |
"source": [ | |
"p1_actions = ['B', 'Q']\n", | |
"p2_actions = ['b', 'd']\n", | |
"\n", | |
"class information_set():\n", | |
" def __init__(self):\n", | |
" self.cumulative_regrets = np.zeros(shape=2)\n", | |
" self.strategy_sum = np.zeros(shape=2)\n", | |
" self.num_actions = 2 # 2 for both players\n", | |
"\n", | |
" def normalize(self, strategy):\n", | |
" \"\"\"Normalize strategy. If no positive regrets, strategy is unif. random\"\"\"\n", | |
" if sum(strategy) > 0:\n", | |
" strategy /= sum(strategy)\n", | |
" else:\n", | |
" strategy = np.ones(self.num_actions)/self.num_actions\n", | |
"\n", | |
" return strategy\n", | |
"\n", | |
" def get_strategy(self, reach_probability, active_player=0, p1_strategy):\n", | |
" \"\"\"Return regret matching strategy\"\"\"\n", | |
" ## modified to include pure strategy best response of P2.\n", | |
" if active_player != 0:\n", | |
" # compute best response\n", | |
" strategy = np.array([0., 0.])\n", | |
" br_action = best_response(p1_info_set, p1_T, p1_W)\n", | |
" strategy[br_action] = 1\n", | |
" self.strategy_sum = strategy\n", | |
" else:\n", | |
" strategy = np.maximum(0, self.cumulative_regrets)\n", | |
" strategy = self.normalize(strategy)\n", | |
"\n", | |
" self.strategy_sum += reach_probability * strategy\n", | |
"\n", | |
" return strategy\n", | |
"\n", | |
" def get_average_strategy(self):\n", | |
" return self.normalize(self.strategy_sum.copy())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "0d4c9090", | |
"metadata": { | |
"id": "0d4c9090" | |
}, | |
"outputs": [], | |
"source": [ | |
"class BeerQuiche():\n", | |
" @staticmethod\n", | |
" def is_terminal(history):\n", | |
" return history in ['Bb', 'Bd', 'Qb', 'Qd']\n", | |
"\n", | |
" @staticmethod\n", | |
" def get_payoff(history, p1_type):\n", | |
" payoff = 0\n", | |
" if p1_type == 'T': # tough\n", | |
" if history == 'Bb':\n", | |
" payoff = 2\n", | |
" elif history == 'Bd':\n", | |
" payoff = 1\n", | |
" elif history == 'Qb':\n", | |
" payoff = 1\n", | |
" else:\n", | |
" payoff = 0\n", | |
" else:\n", | |
" if history == 'Bb':\n", | |
" payoff = -2\n", | |
" elif history == 'Bd':\n", | |
" payoff = 0\n", | |
" elif history == 'Qb':\n", | |
" payoff = -1\n", | |
" else:\n", | |
" payoff = 2\n", | |
"\n", | |
" return payoff" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "b12fe858", | |
"metadata": { | |
"id": "b12fe858" | |
}, | |
"outputs": [], | |
"source": [ | |
"class cfr_trainer():\n", | |
" def __init__(self):\n", | |
" self.infoset_map: Dict[str, info_set] = {}\n", | |
"\n", | |
" def get_information_set(self, type_and_history):\n", | |
" \"\"\"add if needed and return\"\"\"\n", | |
" if type_and_history not in self.infoset_map:\n", | |
" self.infoset_map[type_and_history] = information_set()\n", | |
"\n", | |
" return self.infoset_map[type_and_history]\n", | |
"\n", | |
" def cfr(self, p1_type, history, reach_probabilities, active_player):\n", | |
" if BeerQuiche.is_terminal(history):\n", | |
" return BeerQuiche.get_payoff(history, p1_type)\n", | |
"\n", | |
" if active_player == 0:\n", | |
" info_set = self.get_information_set(p1_type + history)\n", | |
" else:\n", | |
" info_set = self.get_information_set(history)\n", | |
"\n", | |
" strategy = info_set.get_strategy(reach_probabilities[active_player], active_player)\n", | |
"\n", | |
" op = (active_player + 1) % 2\n", | |
" counterfactual_values = np.zeros(2)\n", | |
" if active_player == 0:\n", | |
" actions = p1_actions\n", | |
" else:\n", | |
" actions = p2_actions\n", | |
"\n", | |
" for ix, action in enumerate(actions):\n", | |
" action_probability = strategy[ix]\n", | |
"\n", | |
" new_reach_probabilities = reach_probabilities.copy()\n", | |
" new_reach_probabilities[active_player] *= action_probability\n", | |
"\n", | |
" counterfactual_values[ix] = -self.cfr(p1_type, history + action, new_reach_probabilities, op)\n", | |
"\n", | |
" node_value = counterfactual_values.dot(strategy)\n", | |
" for ix, action in enumerate(actions):\n", | |
" info_set.cumulative_regrets[ix] += reach_probabilities[op] * (counterfactual_values[ix] - node_value)\n", | |
"\n", | |
" return node_value\n", | |
"\n", | |
" def train(self, num_iterations):\n", | |
" util = 0\n", | |
" p = 1/3\n", | |
" types = [1, 2]\n", | |
" for _ in range(num_iterations):\n", | |
" p1_t = np.random.choice(types, p=[p, 1-p])\n", | |
" p1_type = 'T' if p1_t == 1 else 'W'\n", | |
" history = ''\n", | |
" reach_probabilities = np.ones(2)\n", | |
" util += self.cfr(p1_type, history, reach_probabilities, 0)\n", | |
"\n", | |
" return util" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "bd288c1a", | |
"metadata": { | |
"id": "bd288c1a" | |
}, | |
"outputs": [], | |
"source": [ | |
"trainer = cfr_trainer()\n", | |
"num_itrs = 1000\n", | |
"util = trainer.train(num_itrs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "c595e907", | |
"metadata": { | |
"id": "c595e907", | |
"outputId": "caf4b61e-2ebd-47a9-cd5c-fe7cd2a384c1", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"B: [1. 0.]\n", | |
"Q: [1. 0.]\n", | |
"T: [1. 0.]\n", | |
"W: [0. 1.]\n" | |
] | |
} | |
], | |
"source": [ | |
"for name, info_set in sorted(trainer.infoset_map.items(),):\n", | |
" print(f\"{name}: {np.round(info_set.get_average_strategy(), 2)}\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1c828b71", | |
"metadata": { | |
"id": "1c828b71" | |
}, | |
"source": [ | |
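"#### B and Q are P2's information sets (the P1 action that P2 observed); T and W are P1's information sets (P1's type). Each row is the average strategy over that player's actions, in the order [B, Q] for P1 and [b, d] for P2." | |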
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.13" | |
}, | |
"colab": { | |
"provenance": [], | |
"name": "beerquiche_cfr_primal_dual.ipynb", | |
"include_colab_link": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment