Last active
November 26, 2015 04:51
-
-
Save davidshinn/f440c7a1956e9779b287 to your computer and use it in GitHub Desktop.
maze helper for pymdptoolbox
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
__author__ = "David Shinn" | |
__license__ = "MIT" | |
from collections import defaultdict | |
import itertools | |
import os | |
import pdb | |
import numpy as np | |
import pandas as pd | |
class Maze(object):
    """Grid maze loaded from a text file and exposed as MDP matrices.

    The file is read with pandas as whitespace-separated cells.  Cells
    equal to ``+`` are walls; every other cell becomes a state.  After
    construction the instance offers ``transition_matrix`` and
    ``reward_matrix`` (both shaped actions x states x states), plus helpers
    that lay per-state values back out on the maze grid.
    """

    def __init__(self, filename, move_probs=None, terminal_positions=None):
        """Load the maze and build its state mapping and MDP matrices.

        filename -- path (or buffer) handed to ``pandas.read_csv``.
        move_probs -- dict mapping the relative movements 'l', 'fl', 'f',
            'fr', 'r' to probabilities that sum to 1.  Defaults to 0.6
            forward and 0.1 for each of the other four.
        terminal_positions -- iterable of (row, col) positions that absorb
            the agent (probability 1 of staying put).  Defaults to none.
        """
        # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3, so
        # cells are read as text on both interpreters.
        self.df = pd.read_csv(filename, sep=r'\s+',
                              header=None, dtype=type(u''))
        self.n_rows, self.n_cols = self.df.shape
        self.move_probs = move_probs
        if self.move_probs is None:
            self.move_probs = {'f': 0.6, 'fl': 0.1, 'fr': 0.1,
                               'l': 0.1, 'r': 0.1}
        # Compare with a tolerance: summing floats such as 0.6 + 4 * 0.1
        # is not guaranteed to give exactly 1.
        assert abs(sum(self.move_probs.values()) - 1.0) < 1e-9, \
            "move_probs must sum to 1"
        # Compass directions listed clockwise, i.e. left to right.
        self._actions_list = ['n', 'ne', 'e', 'se', 's', 'sw', 'w', 'nw']
        self._move_delta = {
            'e': (0, 1),
            'ne': (-1, 1),
            'n': (-1, 0),
            'nw': (-1, -1),
            'w': (0, -1),
            'sw': (1, -1),
            's': (1, 0),
            'se': (1, 1),
        }
        self.map_to_symbols = {u'e': u'→', u'ne': u'↗', u'n': u'↑',
                               u'nw': u'↖', u'w': u'←', u'sw': u'↙',
                               u's': u'↓', u'se': u'↘'}
        self.terminal_positions = terminal_positions
        if self.terminal_positions is None:
            self.terminal_positions = []
        # Poor man's queue: three back-to-back copies of the compass so a
        # window around any direction never runs off either end.
        self._ordered_directions = self._actions_list * 3
        self._movements_list = ['l', 'fl', 'f', 'fr', 'r']
        self._len_movements = len(self._movements_list)
        self._n_actions = len(self._actions_list)
        self.mapping = create_pos_to_state_map(self.df)
        # State indices are the int keys of the mapping; positions are
        # tuples and the wall list is stored under None.
        self.n_states = sum(1 for key in self.mapping
                            if isinstance(key, int))
        self.reward_matrix = create_reward_matrix(
            df=self.df, mapping=self.mapping, n_states=self.n_states,
            n_actions=self._n_actions)
        self.direction_probs = {}
        for action in self._actions_list:
            self.direction_probs[action] = self.get_direction_probs(action)
        self.transition_matrix = self.create_transition_matrix()

    def __repr__(self):
        return repr(self.df)

    def get_direction_probs(self, action):
        """Return {compass direction: probability} for an intended action.

        ``action`` is treated as the forward ('f') movement; the two
        directions on either side of it receive the 'l', 'fl', 'fr' and
        'r' probabilities from ``move_probs``.
        """
        # Start the search past the first few entries so the window that
        # begins two slots to the left always has a non-negative start.
        pos = self._ordered_directions.index(action, self._len_movements)
        start = pos - 2
        window = self._ordered_directions[start:start + self._len_movements]
        direction_probs = {}
        for movement, direction in zip(self._movements_list, window):
            direction_probs[direction] = self.move_probs[movement]
        return direction_probs

    def get_transition_probs_for_single_state(self, row, col):
        """Return {intended action: {(row, col): probability}}.

        (row, col) must be a non-wall position.  Probability mass of a
        move that would end in a wall or outside the grid is added to the
        current position instead.  Probabilities are rounded to 5 places.
        """
        assert self.mapping.get((row, col), None) is not None
        transition_probs = defaultdict(lambda: defaultdict(float))
        for action_intended, direction_probs in self.direction_probs.items():
            for action_actual, action_prob in direction_probs.items():
                d_row, d_col = self._move_delta[action_actual]
                new_row = row + d_row
                new_col = col + d_col
                if self.mapping.get((new_row, new_col), None) is not None:
                    transition_probs[action_intended][(new_row, new_col)] += action_prob
                else:
                    # Blocked move: the agent stays where it is.
                    transition_probs[action_intended][(row, col)] += action_prob
            # Round away floating point noise from the accumulation.
            for pos in transition_probs[action_intended]:
                transition_probs[action_intended][pos] = round(
                    transition_probs[action_intended][pos], 5)
        return transition_probs

    def create_transition_matrix(self):
        """Return the actions x states x states transition probabilities.

        Terminal positions transition to themselves with probability 1
        under every action.
        """
        transition_matrix = np.zeros(
            (self._n_actions, self.n_states, self.n_states))
        # Normalise to tuples so positions given as lists still match.
        terminal = set(tuple(pos) for pos in self.terminal_positions)
        for originating_state_index in range(self.n_states):
            row, col = self.mapping[originating_state_index]
            transition_probs = self.get_transition_probs_for_single_state(row, col)
            for action_label, state_probs in transition_probs.items():
                action_index = self.get_action_index(action_label)
                # If terminal put 1 where S = S'
                if (row, col) in terminal:
                    transition_matrix[action_index,
                                      originating_state_index,
                                      originating_state_index] = 1
                    continue
                for (new_row, new_col), prob in state_probs.items():
                    new_state_index = self.mapping[(new_row, new_col)]
                    if new_state_index is None:
                        continue
                    transition_matrix[action_index,
                                      originating_state_index,
                                      new_state_index] = prob
        return transition_matrix

    def get_action_index(self, action_label):
        """Return the index of a compass label in the actions list."""
        return self._actions_list.index(action_label)

    def get_action_label(self, action_index):
        """Return the compass label stored at ``action_index``."""
        return self._actions_list[action_index]

    def create_df_from_values(self, values):
        """Return pandas dataframe with values in maze positions and + in wall positions"""
        assert len(values) == self.n_states
        # Explicit object dtype so any kind of value can be written into
        # a cell that initially holds the wall marker.
        df = pd.DataFrame(u'+', index=range(self.n_rows),
                          columns=range(self.n_cols), dtype=object)
        for state_index, value in enumerate(values):
            row, col = self.mapping[state_index]
            df.iloc[row, col] = value
        return df

    def create_policy_visual_from_values(self, policy_values):
        """Return pandas dataframe with arrows for policies"""
        df = self.create_df_from_values(policy_values)
        for action_index, action_label in enumerate(self._actions_list):
            df.replace(action_index, self.map_to_symbols[action_label], inplace=True)
        for row, col in self.terminal_positions:
            df.iloc[row, col] = u'⨂'
        df.replace(u'+', u'█', inplace=True)
        return df
def create_pos_to_state_map(df):
    """Build the two-way lookup between state indices and maze positions.

    The returned dict holds three kinds of keys:

    * int state index -> (row, col) of a non-wall cell,
    * (row, col) -> its state index, or None when the cell is a wall,
    * None -> list of every wall (row, col).

    Wall cells (those equal to '+') receive no state index, so the MDP
    algorithms are not handed states that can never be occupied.
    """
    n_rows, n_cols = df.shape
    walls = []
    mapping = {None: walls}
    next_index = 0
    # Row-major scan: state indices increase left to right, top to bottom.
    for position in itertools.product(range(n_rows), range(n_cols)):
        if df.iloc[position] == '+':
            walls.append(position)
            mapping[position] = None
            continue
        mapping[next_index] = position
        mapping[position] = next_index
        next_index += 1
    return mapping
def create_reward_matrix(df, mapping, n_states, n_actions):
    """Return an A x S x S' reward array with one identical S x S' slice per action.

    Entry [a, s, s'] is the value of the maze cell at s' whenever s is one
    of the eight neighbours of s' and is not a wall.  All other entries,
    including the diagonal s == s', remain zero.
    """
    # The eight neighbour offsets; (0, 0) is excluded so S == S' stays zero.
    neighbour_offsets = [(d_row, d_col)
                         for d_row in (-1, 0, 1)
                         for d_col in (-1, 0, 1)
                         if (d_row, d_col) != (0, 0)]
    per_action = np.zeros((n_states, n_states))
    for target_index in range(n_states):
        row, col = mapping[target_index]
        reward = df.iloc[row, col]
        for d_row, d_col in neighbour_offsets:
            # Off-grid positions are absent from the mapping and walls
            # map to None; both come back as None here.
            source_index = mapping.get((row + d_row, col + d_col))
            if source_index is None:
                continue
            per_action[source_index, target_index] = reward
    reward_matrix = np.zeros((n_actions, n_states, n_states))
    # Broadcast the single S x S' table across every action.
    reward_matrix[:] = per_action
    return reward_matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import Maze\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import mdptoolbox" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"+ + + + + +\r\n", | |
"0 0 0 0 + 0\r\n", | |
"+ 0 + + + 0\r\n", | |
"+ 0 + 0 0 0\r\n", | |
"+ 0 + 0 + +\r\n", | |
"+ 0 0 0 + 0\r\n", | |
"+ 0 + + + 0\r\n", | |
"+ 0 0 0 0 200\r\n" | |
] | |
} | |
], | |
"source": [ | |
"!cat maze_small.txt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"maze = Maze.Maze(filename='maze_small.txt', terminal_positions=[(7, 5)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>+</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>200</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5\n", | |
"0 + + + + + +\n", | |
"1 0 0 0 0 + 0\n", | |
"2 + 0 + + + 0\n", | |
"3 + 0 + 0 0 0\n", | |
"4 + 0 + 0 + +\n", | |
"5 + 0 0 0 + 0\n", | |
"6 + 0 + + + 0\n", | |
"7 + 0 0 0 0 200" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"maze.df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" Iteration Variation\n", | |
" 1 120.000000\n", | |
" 2 64.800000\n", | |
" 3 46.656000\n", | |
" 4 37.791360\n", | |
" 5 27.209779\n", | |
" 6 24.335746\n", | |
" 7 18.430884\n", | |
" 8 15.486902\n", | |
" 9 13.732017\n", | |
" 10 11.364422\n", | |
" 11 9.963593\n", | |
" 12 8.592304\n", | |
" 13 7.306030\n", | |
" 14 6.486740\n", | |
" 15 5.510174\n", | |
" 16 4.898826\n", | |
" 17 4.012067\n", | |
" 18 3.077883\n", | |
" 19 2.239239\n", | |
" 20 1.559553\n", | |
" 21 1.047495\n", | |
" 22 0.682523\n", | |
" 23 0.433492\n", | |
" 24 0.269442\n", | |
" 25 0.164441\n", | |
" 26 0.098816\n", | |
" 27 0.058607\n", | |
" 28 0.034377\n", | |
" 29 0.019976\n", | |
" 30 0.011517\n", | |
" 31 0.006597\n", | |
" 32 0.003758\n", | |
" 33 0.002131\n", | |
" 34 0.001204\n", | |
" 35 0.000678\n", | |
"Iterating stopped, epsilon-optimal policy found.\n" | |
] | |
} | |
], | |
"source": [ | |
"value_iteration = mdptoolbox.mdp.ValueIteration(\n", | |
" transitions=maze.transition_matrix,\n", | |
" reward=maze.reward_matrix,\n", | |
" discount=0.90)\n", | |
"value_iteration.setVerbose()\n", | |
"value_iteration.run()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>↘</td>\n", | |
" <td>↓</td>\n", | |
" <td>↙</td>\n", | |
" <td>←</td>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>↙</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" <td>↙</td>\n", | |
" <td>←</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" <td>█</td>\n", | |
" <td>↙</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" <td>↙</td>\n", | |
" <td>←</td>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>█</td>\n", | |
" <td>↘</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>█</td>\n", | |
" <td>↓</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>█</td>\n", | |
" <td>→</td>\n", | |
" <td>→</td>\n", | |
" <td>→</td>\n", | |
" <td>→</td>\n", | |
" <td>⨂</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5\n", | |
"0 █ █ █ █ █ █\n", | |
"1 ↘ ↓ ↙ ← █ ↓\n", | |
"2 █ ↓ █ █ █ ↙\n", | |
"3 █ ↓ █ ↓ ↙ ←\n", | |
"4 █ ↓ █ ↙ █ █\n", | |
"5 █ ↓ ↙ ← █ ↓\n", | |
"6 █ ↘ █ █ █ ↓\n", | |
"7 █ → → → → ⨂" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"maze.create_policy_visual_from_values(value_iteration.policy)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
+ + + + + + | |
0 0 0 0 + 0 | |
+ 0 + + + 0 | |
+ 0 + 0 0 0 | |
+ 0 + 0 + + | |
+ 0 0 0 + 0 | |
+ 0 + + + 0 | |
+ 0 0 0 0 200 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is really great David!
Thanks for sharing this.