Skip to content

Instantly share code, notes, and snippets.

@ghimiremukesh
Last active April 10, 2024 21:38
Show Gist options
  • Save ghimiremukesh/ff5c44fdb53fb4892fb55703a7c5d891 to your computer and use it in GitHub Desktop.
Save ghimiremukesh/ff5c44fdb53fb4892fb55703a7c5d891 to your computer and use it in GitHub Desktop.
beerquiche_cfr_primal_dual.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ghimiremukesh/ff5c44fdb53fb4892fb55703a7c5d891/beerquiche_cfr.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "c98ba361",
"metadata": {
"id": "c98ba361"
},
"source": [
"### CFR -- Beer Quiche Game"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "43d052ee",
"metadata": {
"id": "43d052ee"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from typing import List, Dict\n",
"import random\n",
"import sys"
]
},
{
"cell_type": "code",
"source": [
"def best_response(info_set, p1_T, p1_W):\n",
"    \"\"\"Return the index of P2's best-response action at one of its information sets.\n",
"\n",
"    P2 only sees P1's action, so it weighs the two P1 types by prior times\n",
"    action probability. With the 1/3 : 2/3 prior sampled in `cfr_trainer.train`\n",
"    the relative weights are `p1_T` and `2 * p1_W`. The expected values are\n",
"    P2's, i.e. the negated payoffs of `BeerQuiche.get_payoff`.\n",
"\n",
"    Args:\n",
"        info_set: 'Q' when P1 played Q; any other value is treated as 'B'.\n",
"        p1_T: probability that P1 of type 'T' (tough) plays this action.\n",
"        p1_W: probability that P1 of type 'W' plays this action.\n",
"\n",
"    Returns:\n",
"        Index into `p2_actions`: 0 for 'b' (bully) or 1 for 'd' (defer).\n",
"        Ties resolve to 0, as does an information set that neither type reaches.\n",
"    \"\"\"\n",
"    # Common normaliser of the posterior; positive whenever the set is reachable.\n",
"    denominator = p1_T + 2 * p1_W\n",
"    if denominator == 0:\n",
"        # Neither type plays this action, so the posterior is undefined.\n",
"        # Return the first action instead of dividing by zero.\n",
"        return 0\n",
"\n",
"    if info_set == 'Q':\n",
"        e_bully = -(p1_T - 2 * p1_W) / denominator\n",
"        e_defer = -(4 * p1_W) / denominator\n",
"    else:\n",
"        e_bully = (4 * p1_W - 2 * p1_T) / denominator\n",
"        e_defer = -p1_T / denominator\n",
"\n",
"    # The original built this array from undefined names (e_b, e_d).\n",
"    evs = np.array([e_bully, e_defer])\n",
"\n",
"    return int(np.argmax(evs))"
],
"metadata": {
"id": "W137Brd962OW"
},
"id": "W137Brd962OW",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 20,
"id": "a7699947",
"metadata": {
"id": "a7699947"
},
"outputs": [],
"source": [
"# Action labels; a strategy vector's index i refers to actions[i].\n",
"p1_actions = list('BQ')\n",
"p2_actions = list('bd')\n",
"\n",
"class information_set():\n",
"    \"\"\"Regret and strategy accumulators for a single information set.\n",
"\n",
"    Both arrays have one entry per action, indexed like the acting player's\n",
"    action list.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        self.num_actions = 2  # 2 for both players\n",
"        self.cumulative_regrets = np.zeros(shape=self.num_actions)\n",
"        self.strategy_sum = np.zeros(shape=self.num_actions)\n",
"\n",
"    def normalize(self, strategy):\n",
"        \"\"\"Normalize strategy. If it has no positive mass, return unif. random.\n",
"\n",
"        Returns a new array; the argument is no longer modified in place.\n",
"        \"\"\"\n",
"        total = np.sum(strategy)\n",
"        if total > 0:\n",
"            return strategy / total\n",
"\n",
"        return np.ones(self.num_actions) / self.num_actions\n",
"\n",
"    def get_strategy(self, reach_probability, active_player=0, p1_strategy=None):\n",
"        \"\"\"Return this information set's current strategy and accumulate it.\n",
"\n",
"        P1 (`active_player == 0`) always plays regret matching. P2 plays a pure\n",
"        best response when `p1_strategy` is supplied and falls back to regret\n",
"        matching otherwise, so two-argument calls keep working.\n",
"\n",
"        Args:\n",
"            reach_probability: weight applied to the strategy before it is\n",
"                added to `strategy_sum`.\n",
"            active_player: 0 for P1, anything else for P2.\n",
"            p1_strategy: optional `(p1_info_set, p1_T, p1_W)` triple that is\n",
"                forwarded to `best_response`. NOTE(review): the original read\n",
"                these three names without defining them; confirm the triple\n",
"                matches what the caller is meant to pass.\n",
"\n",
"        Returns:\n",
"            Array of `num_actions` action probabilities.\n",
"        \"\"\"\n",
"        ## modified to include pure strategy best response of P2.\n",
"        if active_player != 0 and p1_strategy is not None:\n",
"            p1_info_set, p1_T, p1_W = p1_strategy\n",
"            strategy = np.zeros(self.num_actions)\n",
"            br_action = best_response(p1_info_set, p1_T, p1_W)\n",
"            strategy[br_action] = 1\n",
"            # Reset the running sum so the average strategy reports the latest\n",
"            # best response. Copy: aliasing `strategy` here would let the\n",
"            # `+=` below scale the array that is returned to the caller.\n",
"            self.strategy_sum = strategy.copy()\n",
"        else:\n",
"            strategy = self.normalize(np.maximum(0, self.cumulative_regrets))\n",
"\n",
"        self.strategy_sum += reach_probability * strategy\n",
"\n",
"        return strategy\n",
"\n",
"    def get_average_strategy(self):\n",
"        \"\"\"Return the normalized `strategy_sum` without modifying it.\"\"\"\n",
"        return self.normalize(self.strategy_sum.copy())"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "0d4c9090",
"metadata": {
"id": "0d4c9090"
},
"outputs": [],
"source": [
"class BeerQuiche():\n",
"    \"\"\"Terminal test and payoff table for the Beer-Quiche game.\"\"\"\n",
"\n",
"    @staticmethod\n",
"    def is_terminal(history):\n",
"        \"\"\"Return True when `history` is one of the four complete action pairs.\"\"\"\n",
"        return history in ('Bb', 'Bd', 'Qb', 'Qd')\n",
"\n",
"    @staticmethod\n",
"    def get_payoff(history, p1_type):\n",
"        \"\"\"Look up the payoff for `history` given P1's type.\n",
"\n",
"        `p1_type == 'T'` selects the tough table; every other value selects\n",
"        the second table. A history missing from the chosen table gets that\n",
"        table's fallback value, which is also the 'Qd' payoff.\n",
"        \"\"\"\n",
"        if p1_type == 'T':  # tough\n",
"            table, fallback = {'Bb': 2, 'Bd': 1, 'Qb': 1}, 0\n",
"        else:\n",
"            table, fallback = {'Bb': -2, 'Bd': 0, 'Qb': -1}, 2\n",
"\n",
"        return table.get(history, fallback)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "b12fe858",
"metadata": {
"id": "b12fe858"
},
"outputs": [],
"source": [
"class cfr_trainer():\n",
"    \"\"\"CFR trainer for the Beer-Quiche game that samples P1's type each iteration.\n",
"\n",
"    At P2's nodes the trainer passes P1's current action probabilities to\n",
"    `get_strategy` so that P2 can play a best response.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        # Keys: P1 uses type + history ('T' / 'W'), P2 uses the history ('B' / 'Q').\n",
"        self.infoset_map: Dict[str, \"information_set\"] = {}\n",
"\n",
"    def get_information_set(self, type_and_history):\n",
"        \"\"\"add if needed and return\"\"\"\n",
"        if type_and_history not in self.infoset_map:\n",
"            self.infoset_map[type_and_history] = information_set()\n",
"\n",
"        return self.infoset_map[type_and_history]\n",
"\n",
"    def _p1_action_probabilities(self, action):\n",
"        \"\"\"Return (p1_T, p1_W): each P1 type's current probability of playing `action`.\"\"\"\n",
"        ix = p1_actions.index(action)\n",
"        probabilities = []\n",
"        for p1_type in ('T', 'W'):\n",
"            node = self.infoset_map.get(p1_type)\n",
"            if node is None:\n",
"                # Not visited yet: a fresh information set plays uniformly.\n",
"                probabilities.append(1 / len(p1_actions))\n",
"            else:\n",
"                # Regret matching read-out that leaves strategy_sum untouched.\n",
"                current = node.normalize(np.maximum(0, node.cumulative_regrets))\n",
"                probabilities.append(current[ix])\n",
"\n",
"        return tuple(probabilities)\n",
"\n",
"    def cfr(self, p1_type, history, reach_probabilities, active_player):\n",
"        \"\"\"Recursively evaluate `history` and update regrets at the visited node.\n",
"\n",
"        Args:\n",
"            p1_type: 'T' or 'W', the type sampled for this iteration.\n",
"            history: actions played so far, concatenated into a string.\n",
"            reach_probabilities: two-element array, one reach probability per player.\n",
"            active_player: 0 for P1, 1 for P2.\n",
"\n",
"        Returns:\n",
"            The payoff at a terminal history, otherwise the strategy-weighted\n",
"            value of the node for `active_player`.\n",
"        \"\"\"\n",
"        if BeerQuiche.is_terminal(history):\n",
"            return BeerQuiche.get_payoff(history, p1_type)\n",
"\n",
"        reach = reach_probabilities[active_player]\n",
"        if active_player == 0:\n",
"            actions = p1_actions\n",
"            info_set = self.get_information_set(p1_type + history)\n",
"            strategy = info_set.get_strategy(reach, active_player)\n",
"        else:\n",
"            actions = p2_actions\n",
"            info_set = self.get_information_set(history)\n",
"            # At P2's node the history is exactly P1's action. The original call\n",
"            # omitted P1's strategy, so the best response could not be computed.\n",
"            # NOTE(review): the triple layout mirrors best_response's arguments.\n",
"            p1_T, p1_W = self._p1_action_probabilities(history)\n",
"            strategy = info_set.get_strategy(reach, active_player, p1_strategy=(history, p1_T, p1_W))\n",
"\n",
"        op = (active_player + 1) % 2\n",
"        counterfactual_values = np.zeros(len(actions))\n",
"        for ix, action in enumerate(actions):\n",
"            new_reach_probabilities = reach_probabilities.copy()\n",
"            new_reach_probabilities[active_player] *= strategy[ix]\n",
"\n",
"            # Values flip sign at every level because the players alternate.\n",
"            counterfactual_values[ix] = -self.cfr(p1_type, history + action, new_reach_probabilities, op)\n",
"\n",
"        node_value = counterfactual_values.dot(strategy)\n",
"        # Regrets are weighted by the opponent's reach probability.\n",
"        info_set.cumulative_regrets += reach_probabilities[op] * (counterfactual_values - node_value)\n",
"\n",
"        return node_value\n",
"\n",
"    def train(self, num_iterations):\n",
"        \"\"\"Run `num_iterations` iterations and return the summed root values.\"\"\"\n",
"        util = 0\n",
"        p = 1/3  # probability of type 'T'; best_response's 1 : 2 weights assume it\n",
"        types = [1, 2]\n",
"        for _ in range(num_iterations):\n",
"            p1_t = np.random.choice(types, p=[p, 1-p])\n",
"            p1_type = 'T' if p1_t == 1 else 'W'\n",
"            history = ''\n",
"            reach_probabilities = np.ones(2)\n",
"            util += self.cfr(p1_type, history, reach_probabilities, 0)\n",
"\n",
"        return util"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "bd288c1a",
"metadata": {
"id": "bd288c1a"
},
"outputs": [],
"source": [
"SEED = 0  # fixed so Restart & Run All samples the same sequence of P1 types\n",
"np.random.seed(SEED)\n",
"\n",
"trainer = cfr_trainer()\n",
"num_itrs = 1000\n",
"util = trainer.train(num_itrs)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "c595e907",
"metadata": {
"id": "c595e907",
"outputId": "caf4b61e-2ebd-47a9-cd5c-fe7cd2a384c1",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"B: [1. 0.]\n",
"Q: [1. 0.]\n",
"T: [1. 0.]\n",
"W: [0. 1.]\n"
]
}
],
"source": [
"# Print every information set's average strategy, keys in sorted order.\n",
"for key in sorted(trainer.infoset_map):\n",
"    average = np.round(trainer.infoset_map[key].get_average_strategy(), 2)\n",
"    print(f\"{key}: {average}\")"
]
},
{
"cell_type": "markdown",
"id": "1c828b71",
"metadata": {
"id": "1c828b71"
},
"source": [
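"#### B and Q are the info-sets of P2 (the action it saw P1 take). T and W are the info-sets of P1 (its type). Each row is the average strategy over that player's actions: [B, Q] for P1 and [b, d] for P2."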
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"colab": {
"provenance": [],
"name": "beerquiche_cfr_primal_dual.ipynb",
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment