Created
March 28, 2017 22:38
-
-
Save eadanfahey/ef5fef7cfc31a1688421f56a381afc6b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import itertools\n", | |
"import math\n", | |
"from operator import itemgetter" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def sort_reference(pairs, reference):\n", | |
" \"\"\"\n", | |
" Sort a list of pairs according to a reference list.\n", | |
" Parameters:\n", | |
" pairs: (iterable[tuple]) a collection of (key, value) pairs.\n", | |
" reference: (list) the reference list to sort by.\n", | |
" Returns:\n", | |
" (list[tuple]) the list of pairs sorted according\n", | |
" to `reference`.\n", | |
" Example:\n", | |
" pairs = [('B', 8), ('G', 3), ('R', 7)]\n", | |
" reference = ['R', 'G', 'B]\n", | |
" sort_reference(pairs, reference) == [('R', 7), ('G', 3), ('B', 8)]\n", | |
" \"\"\"\n", | |
" \n", | |
" n = len(reference)\n", | |
" ref_dict = dict(zip(reference, range(n)))\n", | |
" pairs_ordering = [(p, ref_dict[p[0]]) for p in pairs]\n", | |
" return [_[0] for _ in sorted(pairs_ordering, key=itemgetter(-1))]\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def marginal_value(player_permutation, coalition_values):\n", | |
" \"\"\"\n", | |
" Calculate the marginal value for each player in a given permutation.\n", | |
" Parameters:\n", | |
" player_permutation: (tuple) of player ids.\n", | |
" coalition_values: (dict[tuple]: real) map between player coalitions and value.\n", | |
" Returns:\n", | |
" (tuple) the marginal value for each player. This tuple is in\n", | |
" the same order as the `player_permutation`, i.e. the first marginal value\n", | |
" corresponds to first player.\n", | |
" \"\"\"\n", | |
" n = len(player_permutation)\n", | |
" marginal_val = [0] * n\n", | |
" \n", | |
" accumulated_val = 0\n", | |
" \n", | |
" for i in range(n):\n", | |
" subset = player_permutation[:i+1]\n", | |
" coalition_val = coalition_values[tuple(sorted(subset))]\n", | |
" \n", | |
" incremental_val = coalition_val - accumulated_val\n", | |
" marginal_val[i] = incremental_val\n", | |
" accumulated_val += incremental_val\n", | |
" \n", | |
" return marginal_val\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 107, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def shapley_value(players, coalition_values):\n", | |
" \"\"\"\n", | |
" Calculate the shapley value for a given set of players.\n", | |
" Parameters:\n", | |
" players: (tuple) the set of players.\n", | |
" coalition_values: (dict[tuple]: real) the outcome value when a\n", | |
" subsets of the players play together.\n", | |
" Returns:\n", | |
" (tuple) the shapley value for each player.\n", | |
" \"\"\"\n", | |
" \n", | |
" if len(players) == 0:\n", | |
" return (0,)\n", | |
" \n", | |
" n = len(players)\n", | |
" n_permutations = math.factorial(n)\n", | |
" sort_players = sorted(players)\n", | |
" total_marginal_val = np.array([0] * n)\n", | |
" \n", | |
" for pp in itertools.permutations(players, n):\n", | |
" marginal_val = marginal_value(pp, coalition_values)\n", | |
" \n", | |
" # order the marginal values by player id\n", | |
" ordered_mv = [_[1] for _ in sorted(zip(pp, marginal_val))]\n", | |
" total_marginal_val += np.array(ordered_mv)\n", | |
" \n", | |
" # these are ordered according to `sort_players`\n", | |
" shapley_val = total_marginal_val / n_permutations\n", | |
" \n", | |
" # order the shapley values according to `players`\n", | |
" ordered_shapley_val = sort_reference(zip(sort_players, shapley_val), players)\n", | |
" \n", | |
" return ordered_shapley_val\n", | |
" \n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Example\n", | |
"\n", | |
"A game with three players: `R`, `G` and `B`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 108, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('R', 7.666666666666667),\n", | |
" ('G', 3.1666666666666665),\n", | |
" ('B', 8.1666666666666661)]" | |
] | |
}, | |
"execution_count": 108, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"coalition_values = {\n", | |
" ('G',): 4,\n", | |
" ('R',): 7,\n", | |
" ('B',): 6,\n", | |
" ('G', 'R'): 7,\n", | |
" ('B', 'R'): 15,\n", | |
" ('B', 'G'): 9,\n", | |
" ('B', 'G', 'R'): 19\n", | |
"}\n", | |
"\n", | |
"shapley_value(('R', 'G', 'B'), coalition_values)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment