Skip to content

Instantly share code, notes, and snippets.

@davidshinn
Last active November 26, 2015 04:51
Show Gist options
  • Save davidshinn/f440c7a1956e9779b287 to your computer and use it in GitHub Desktop.
Save davidshinn/f440c7a1956e9779b287 to your computer and use it in GitHub Desktop.
Maze helper for pymdptoolbox: builds transition and reward matrices from a plain-text maze file.
# -*- coding: utf-8 -*-
__author__ = "David Shinn"
__license__ = "MIT"
from collections import defaultdict
import itertools
import os
import pdb
import numpy as np
import pandas as pd
class Maze(object):
    """Grid maze converted into the matrices pymdptoolbox expects.

    The maze file is a whitespace-separated grid of cells: ``+`` marks a
    wall and every other cell holds the numeric reward earned when stepping
    into that cell from a neighbouring cell.  Each non-wall cell becomes one
    MDP state.  The eight compass directions are the actions; the intended
    direction is taken with probability ``move_probs['f']`` and the agent
    veers 45 or 90 degrees left/right with the remaining probability.  A
    move into a wall or off the grid leaves the agent where it was.

    Attributes:
        df: the raw maze grid as a DataFrame of text cells.
        mapping: two-way map between state indices and (row, col)
            positions, as built by ``create_pos_to_state_map``.
        n_states: number of non-wall cells.
        transition_matrix: (A, S, S) array of transition probabilities.
        reward_matrix: (A, S, S) array of rewards.
    """

    def __init__(self, filename, move_probs=None, terminal_positions=None):
        """Load the maze from *filename* and build the MDP matrices.

        Args:
            filename: path (or buffer) of the whitespace-separated maze grid.
            move_probs: dict mapping each relative movement ('l', 'fl', 'f',
                'fr', 'r') to its probability; the values must sum to 1.
                Defaults to 0.6 forward and 0.1 for each deflection.
            terminal_positions: iterable of (row, col) cells that are
                absorbing: every action keeps the agent there.

        Raises:
            ValueError: if the values of *move_probs* do not sum to 1.
        """
        # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3, so
        # every cell is read as text on either interpreter.
        self.df = pd.read_csv(filename, sep=r'\s+', header=None,
                              dtype=type(u''))
        self.n_rows, self.n_cols = self.df.shape

        self.move_probs = move_probs
        if self.move_probs is None:
            self.move_probs = {'f': 0.6, 'fl': 0.1, 'fr': 0.1,
                               'l': 0.1, 'r': 0.1}
        self._validate_move_probs(self.move_probs)

        # Compass directions in clockwise order: neighbours in this list are
        # 45 degrees apart, which get_direction_probs relies on.
        self._actions_list = ['n', 'ne', 'e', 'se', 's', 'sw', 'w', 'nw']
        self._move_delta = {
            'e': (0, 1),
            'ne': (-1, 1),
            'n': (-1, 0),
            'nw': (-1, -1),
            'w': (0, -1),
            'sw': (1, -1),
            's': (1, 0),
            'se': (1, 1),
        }
        self.map_to_symbols = {u'e': u'→', u'ne': u'↗', u'n': u'↑',
                               u'nw': u'↖', u'w': u'←', u'sw': u'↙',
                               u's': u'↓', u'se': u'↘'}

        if terminal_positions is None:
            terminal_positions = []
        # Normalise to tuples so membership tests against (row, col) keys
        # also work when the caller passes lists such as [[7, 5]].
        self.terminal_positions = [tuple(pos) for pos in terminal_positions]

        # Relative movements ordered left-to-right around the forward move.
        self._movements_list = ['l', 'fl', 'f', 'fr', 'r']
        self._len_movements = len(self._movements_list)
        self._n_actions = len(self._actions_list)

        self.mapping = create_pos_to_state_map(self.df)
        # Integer keys of the mapping are state indices; count them.
        self.n_states = sum(1 for key in self.mapping if isinstance(key, int))
        self.reward_matrix = create_reward_matrix(
            df=self.df, mapping=self.mapping, n_states=self.n_states,
            n_actions=self._n_actions)
        self.direction_probs = dict(
            (action, self.get_direction_probs(action))
            for action in self._actions_list)
        self.transition_matrix = self.create_transition_matrix()

    @staticmethod
    def _validate_move_probs(move_probs):
        """Raise ValueError unless the values of *move_probs* sum to 1.

        A tolerance is used because summing decimal fractions such as 0.1
        in floating point can land slightly off 1.0, depending on the order
        of summation.
        """
        total = sum(move_probs.values())
        if not np.isclose(total, 1.0):
            raise ValueError(
                'move_probs must sum to 1, got {0!r}'.format(total))

    def __repr__(self):
        return repr(self.df)

    def get_direction_probs(self, action):
        """Return {compass direction: probability} with *action* as forward.

        Movement 'f' keeps *action*; 'fl'/'fr' rotate it one step (45
        degrees) and 'l'/'r' two steps (90 degrees) counter-clockwise /
        clockwise.  Each direction gets ``self.move_probs[movement]``.
        """
        n_directions = len(self._actions_list)
        forward = self._actions_list.index(action)
        half_span = len(self._movements_list) // 2
        direction_probs = {}
        # Movements map to offsets -2..+2 around the forward direction in the
        # clockwise compass list; the modulo wraps around (left of 'n' is 'w').
        for offset, movement in enumerate(self._movements_list, -half_span):
            direction = self._actions_list[(forward + offset) % n_directions]
            direction_probs[direction] = self.move_probs[movement]
        return direction_probs

    def get_transition_probs_for_single_state(self, row, col):
        """Return {intended action: {(row, col): probability}} for one cell.

        A movement that would land on a wall or off the grid keeps the agent
        at (row, col).  Accumulated probabilities are rounded to 5 decimals
        to suppress floating-point noise.

        Raises:
            ValueError: if (row, col) is a wall or outside the maze.
        """
        if self.mapping.get((row, col)) is None:
            raise ValueError(
                '({0}, {1}) is not an open maze cell'.format(row, col))
        transition_probs = defaultdict(lambda: defaultdict(float))
        for action_intended, direction_probs in self.direction_probs.items():
            outcomes = transition_probs[action_intended]
            for action_actual, action_prob in direction_probs.items():
                d_row, d_col = self._move_delta[action_actual]
                target = (row + d_row, col + d_col)
                if self.mapping.get(target) is None:
                    target = (row, col)  # blocked: the agent stays put
                outcomes[target] += action_prob
            for pos in outcomes:
                outcomes[pos] = round(outcomes[pos], 5)
        return transition_probs

    def create_transition_matrix(self):
        """Return the (A, S, S) transition-probability array.

        Entry [a, s, s2] is the probability of ending in state s2 after
        intending action a in state s.  Terminal positions are absorbing:
        every action maps them onto themselves with probability 1.
        """
        transition_matrix = np.zeros(
            (self._n_actions, self.n_states, self.n_states))
        for key, state_index in self.mapping.items():
            # Only (row, col) keys of open cells describe originating states.
            if not isinstance(key, tuple) or state_index is None:
                continue
            if key in self.terminal_positions:
                transition_matrix[:, state_index, state_index] = 1
                continue
            transition_probs = self.get_transition_probs_for_single_state(*key)
            for action_label, state_probs in transition_probs.items():
                action_index = self.get_action_index(action_label)
                for position, prob in state_probs.items():
                    transition_matrix[action_index, state_index,
                                      self.mapping[position]] = prob
        return transition_matrix

    def get_action_index(self, action_label):
        """Return the action index of compass direction *action_label*."""
        return self._actions_list.index(action_label)

    def get_action_label(self, action_index):
        """Return the compass direction for *action_index*."""
        return self._actions_list[action_index]

    def create_df_from_values(self, values):
        """Return a maze-shaped DataFrame with *values* in the open cells.

        *values* is indexed by state; wall cells hold u'+'.

        Raises:
            ValueError: if len(values) differs from the number of states.
        """
        if len(values) != self.n_states:
            raise ValueError('expected {0} values, got {1}'.format(
                self.n_states, len(values)))
        # Object dtype so arbitrary values (numbers, symbols) fit any cell.
        df = pd.DataFrame(u'+', index=range(self.n_rows),
                          columns=range(self.n_cols), dtype=object)
        for state_index, value in enumerate(values):
            row, col = self.mapping[state_index]
            df.iloc[row, col] = value
        return df

    def create_policy_visual_from_values(self, policy_values):
        """Return a maze-shaped DataFrame drawing a policy as arrows.

        *policy_values* holds one action index per state (for example the
        ``policy`` of a solved mdptoolbox model).  Terminal cells are drawn
        as u'⨂' and walls as u'█'.
        """
        arrows = dict((index, self.map_to_symbols[label])
                      for index, label in enumerate(self._actions_list))
        symbols = [arrows.get(value, value) for value in policy_values]
        df = self.create_df_from_values(symbols)
        for row, col in self.terminal_positions:
            df.iloc[row, col] = u'⨂'
        df.replace(u'+', u'█', inplace=True)
        return df
def create_pos_to_state_map(df):
    """Build a two-way lookup between MDP state indices and maze cells.

    Open cells are numbered 0, 1, 2, ... in row-major order and recorded
    both as ``index -> (row, col)`` and ``(row, col) -> index``.  Wall cells
    (``'+'``) map to None, and the ``None`` key holds the list of every wall
    position.  Walls therefore never become MDP states, so the solver does
    not waste work on them.
    """
    walls = []
    lookup = {None: walls}
    n_rows, n_cols = df.shape
    next_state = 0
    for position in itertools.product(range(n_rows), range(n_cols)):
        if df.iloc[position] == '+':
            walls.append(position)
            lookup[position] = None
            continue
        lookup[next_state] = position
        lookup[position] = next_state
        next_state += 1
    return lookup
def create_reward_matrix(df, mapping, n_states, n_actions):
    """Return an (A, S, S') reward array whose S x S' slices are identical.

    Entry [a, s, s'] is the reward stored in the maze cell of state s' when
    s is one of its (up to eight) neighbouring open cells.  Self-transitions
    (s == s') and non-adjacent pairs stay 0, so a reward is earned only by
    moving into a cell from a neighbour, regardless of the action.

    Args:
        df: maze DataFrame; open cells hold numeric text such as '0' or '200'.
        mapping: two-way state/position map from create_pos_to_state_map.
        n_states: number of open cells (states).
        n_actions: number of actions (length of the first axis).

    Raises:
        ValueError: if an open cell's text cannot be parsed as a float.
    """
    state_matrix = np.zeros((n_states, n_states))
    neighbour_offsets = [offset
                         for offset in itertools.product((-1, 0, 1), repeat=2)
                         if offset != (0, 0)]
    for state_index in range(n_states):
        row, col = mapping[state_index]
        # Cells are read as text; convert explicitly rather than relying on
        # NumPy coercing the string when it is stored in the float array.
        reward = float(df.iloc[row, col])
        for d_row, d_col in neighbour_offsets:
            # Missing keys are off-grid, None values are walls.
            origin = mapping.get((row + d_row, col + d_col))
            if origin is not None:
                state_matrix[origin, state_index] = reward
    # np.tile copies, so each action gets its own independent slice.
    return np.tile(state_matrix, (n_actions, 1, 1))
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import Maze\n",
"import pandas as pd\n",
"import numpy as np\n",
"import mdptoolbox"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+ + + + + +\r\n",
"0 0 0 0 + 0\r\n",
"+ 0 + + + 0\r\n",
"+ 0 + 0 0 0\r\n",
"+ 0 + 0 + +\r\n",
"+ 0 0 0 + 0\r\n",
"+ 0 + + + 0\r\n",
"+ 0 0 0 0 200\r\n"
]
}
],
"source": [
"!cat maze_small.txt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"maze = Maze.Maze(filename='maze_small.txt', terminal_positions=[(7, 5)])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>+</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>200</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5\n",
"0 + + + + + +\n",
"1 0 0 0 0 + 0\n",
"2 + 0 + + + 0\n",
"3 + 0 + 0 0 0\n",
"4 + 0 + 0 + +\n",
"5 + 0 0 0 + 0\n",
"6 + 0 + + + 0\n",
"7 + 0 0 0 0 200"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"maze.df"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Iteration Variation\n",
" 1 120.000000\n",
" 2 64.800000\n",
" 3 46.656000\n",
" 4 37.791360\n",
" 5 27.209779\n",
" 6 24.335746\n",
" 7 18.430884\n",
" 8 15.486902\n",
" 9 13.732017\n",
" 10 11.364422\n",
" 11 9.963593\n",
" 12 8.592304\n",
" 13 7.306030\n",
" 14 6.486740\n",
" 15 5.510174\n",
" 16 4.898826\n",
" 17 4.012067\n",
" 18 3.077883\n",
" 19 2.239239\n",
" 20 1.559553\n",
" 21 1.047495\n",
" 22 0.682523\n",
" 23 0.433492\n",
" 24 0.269442\n",
" 25 0.164441\n",
" 26 0.098816\n",
" 27 0.058607\n",
" 28 0.034377\n",
" 29 0.019976\n",
" 30 0.011517\n",
" 31 0.006597\n",
" 32 0.003758\n",
" 33 0.002131\n",
" 34 0.001204\n",
" 35 0.000678\n",
"Iterating stopped, epsilon-optimal policy found.\n"
]
}
],
"source": [
"value_iteration = mdptoolbox.mdp.ValueIteration(\n",
" transitions=maze.transition_matrix,\n",
" reward=maze.reward_matrix,\n",
" discount=0.90)\n",
"value_iteration.setVerbose()\n",
"value_iteration.run()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>↘</td>\n",
" <td>↓</td>\n",
" <td>↙</td>\n",
" <td>←</td>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>↙</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" <td>↙</td>\n",
" <td>←</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" <td>█</td>\n",
" <td>↙</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" <td>↙</td>\n",
" <td>←</td>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>█</td>\n",
" <td>↘</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>█</td>\n",
" <td>↓</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>█</td>\n",
" <td>→</td>\n",
" <td>→</td>\n",
" <td>→</td>\n",
" <td>→</td>\n",
" <td>⨂</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5\n",
"0 █ █ █ █ █ █\n",
"1 ↘ ↓ ↙ ← █ ↓\n",
"2 █ ↓ █ █ █ ↙\n",
"3 █ ↓ █ ↓ ↙ ←\n",
"4 █ ↓ █ ↙ █ █\n",
"5 █ ↓ ↙ ← █ ↓\n",
"6 █ ↘ █ █ █ ↓\n",
"7 █ → → → → ⨂"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"maze.create_policy_visual_from_values(value_iteration.policy)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
+ + + + + +
0 0 0 0 + 0
+ 0 + + + 0
+ 0 + 0 0 0
+ 0 + 0 + +
+ 0 0 0 + 0
+ 0 + + + 0
+ 0 0 0 0 200
@ajwije
Copy link

ajwije commented Nov 21, 2015

This is really great David!
Thanks for sharing this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment