Active and passive temporal difference learning in grid world for AI
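The Python 2 script below implements both variants. After every transition from a state s to its successor s', TemporalDifferenceLearningAlgo.update applies the standard TD(0) rule, with the per-step reward r fixed at 0, gamma the discount_factor, and the learning rate alpha(Ns), here 1/(Ns + 1), decaying with the visit count Ns of s:

    U(s) <- U(s) + alpha(Ns) * (r + gamma*U(s') - U(s))

The passive run follows a fixed policy under a stochastic transition model (0.8 in the intended direction, 0.1 to each side); the active run additionally re-derives a greedy policy from the learned utilities every 100 trials.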
#!/usr/bin/env python
#-------------------------------------------------------------------------------
# Author: Pat Kujawa
#-------------------------------------------------------------------------------
import numpy as np
import random
from collections import defaultdict
from pprint import pprint, pformat

class Log():
    @classmethod
    def log(cls, msg):
        print msg
    @classmethod
    def i(cls, msg):
        return  # info logging is disabled; remove this early return to enable it
        cls.log('INFO: {}'.format(msg))

class obj(object):
    def __init__(self, **kwargs):
        '''Create obj with attributes/values as in kwargs dict'''
        self.__dict__.update(kwargs)
    def __str__(self):
        return pformat(self.__dict__)
    def __repr__(self): return self.__str__()

class Cell(obj):
    def __init__(self, **kwargs):
        self.is_terminal = False
        self.policy = ''
        self.value = 0
        super(Cell, self).__init__(**kwargs)  # override defaults if set
    def __str__(self):
        return pformat(self.__dict__)
    def __repr__(self): return self.__str__()
    def __eq__(self, other):
        if isinstance(other, Cell):
            return self.__dict__ == other.__dict__
        else:
            return self.__dict__ == other

world_bounds_rows = 3
world_bounds_cols = 4

class GridWorld(object):
    def __init__(self, array):
        cells = []
        for row_idx, row in enumerate(array):
            cells.append([])
            for col_idx, value in enumerate(row):
                cell = Cell()
                cell.actual_move_to_count_map = defaultdict(lambda: 0)
                cell.row = row_idx
                cell.col = col_idx
                cell.value = value
                # TODO get reward from input
                # If the value is nonzero, NaN, or +-Inf, set as reward too
                cell.reward = value or 0
                cell.blocks = np.isnan(value)
                cells[row_idx].append(cell)
            assert len(cells[row_idx]) == world_bounds_cols
        assert len(cells) == world_bounds_rows
        cells[0][3].is_terminal = True
        cells[1][3].is_terminal = True
        self.cells = cells
        self.policy_to_move_map = {
            'n': self.north, 's': self.south, 'e': self.east, 'w': self.west
        }
    def cells_str(self):
        rows = []
        for row in self.cells:
            rows.append(''.join('{:4},{:5} '.format(c.value, c.policy) for c in row))
        return '\n'.join(rows)
    def policy(self, policy=None):
        '''Get or set the policy of all cells. The policy arg must have the same number of rows and columns as the grid.'''
        if not policy:
            return [[c.policy for c in rows] for rows in self.cells]
        assert len(policy) == len(self.cells)
        for row, row_policy in zip(self.cells, policy):
            assert len(row) == len(row_policy)
            for cell, cell_policy in zip(row, row_policy):
                cell.policy = cell_policy
    def __iter__(self):
        for row in self.cells:
            for cell in row:
                yield cell
    def as_array(self):
        a = [[cell.value for cell in row] for row in self.cells]
        return np.array(a)
    def start_cell(self, **kwargs):
        if not kwargs:
            return self._start_cell
        row, col = kwargs['row'], kwargs['col']
        self._start_cell = self.cells[row][col]
    def next_following_policy(self, cell):
        # Map policy to NSEW
        p = cell.policy
        if not p: raise ValueError('Cell does not have a policy, so we must be done or you made a mistake, e.g. your policy moves into a wall. Cell was: '+str(cell))
        first_char = p[0].lower()
        move = self.policy_to_move_map[first_char]
        next_cell = move(cell)
        return next_cell
    def north(self, cell):
        row = cell.row; col = cell.col
        if row > 0:
            row -= 1
        return self._next_cell_if_not_blocked(cell, row, col, 'n')
    def south(self, cell):
        row = cell.row; col = cell.col
        if row < world_bounds_rows-1:
            row += 1
        return self._next_cell_if_not_blocked(cell, row, col, 's')
    def west(self, cell):
        row = cell.row; col = cell.col
        if col > 0:
            col -= 1
        return self._next_cell_if_not_blocked(cell, row, col, 'w')
    def east(self, cell):
        row = cell.row; col = cell.col
        if col < world_bounds_cols-1:
            col += 1
        return self._next_cell_if_not_blocked(cell, row, col, 'e')
    def _next_cell_if_not_blocked(self, cell, row, col, direction):
        cell.actual_move = direction  # Staying in place is implied if that's the case
        n = self.cells[row][col]
        if n.blocks:
            return cell
        return n
    def __str__(self):
        return pformat(self.as_array())
    def __repr__(self): return self.__str__()

class TemporalDifferenceLearningAlgo(obj):
    def __init__(self, world, **kwargs):
        ## self.learning_factor = lambda Ns: 1 #default
        super(TemporalDifferenceLearningAlgo, self).__init__(
            world=world, **kwargs)
        self.reward = 0
        self._init_itercnts()
        self._init_utilmap()
        north = self.world.north
        south = self.world.south
        east = self.world.east
        west = self.world.west
        self.policy_changed = False
        # Mappings to pick move given random value
        self.stochastic_map = {
            'n': lambda r: north if r<=0.8 else east if r<=0.9 else west,
            's': lambda r: south if r<=0.8 else east if r<=0.9 else west,
            'e': lambda r: east if r<=0.8 else north if r<=0.9 else south,
            'w': lambda r: west if r<=0.8 else north if r<=0.9 else south
        }
    def _init_itercnts(self):
        self.iteration_map = []
        for row in self.world.cells:
            self.iteration_map.append([0 for c in row])
    def _init_utilmap(self, and_iters=False):
        self.utility_map = []
        for row in self.world.cells:
            umap = []
            for idx, c in enumerate(row):
                umap.append(c.value)  # could be NaN
                #TODO maybe use 'reward' to initialize
            self.utility_map.append(umap)
    def next_following_stochastic_policy(self, cell):
        policy = cell.policy
        if not policy: raise ValueError('The cell you gave me did not have a policy. Maybe you didnt check for a terminal. The cell was: {}'.format(cell))
        direction = policy[0].lower()
        move_func = self.stochastic_map[direction](random.uniform(0, 1))
        return move_func(cell)
    def update(self, stochastic=False):
        ## Log.i('update()'.center(50,'-'))
        # Start at start_cell and follow policies to update
        cell = self.world.start_cell()
        next_func = self.world.next_following_policy
        if stochastic:
            next_func = self.next_following_stochastic_policy
        icnt = 0
        while cell and not cell.is_terminal:
            icnt += 1
            ## Log.i(' update_iteration: {}'.format(icnt))
            ## Log.i('cell: {}'.format(cell))
            # Update cell's count
            self.iteration_map[cell.row][cell.col] += 1
            iter_cnt = self.iteration_map[cell.row][cell.col]
            ## Log.i('Ns[cell]: {}'.format(iter_cnt-1))
            # Follow cell's policy to get next cell's utility
            next_cell = next_func(cell)
            rel_dir = self._movement_relative(cell)
            cell.actual_move_to_count_map[rel_dir] += 1
            ## Log.i('next_cell: {}'.format(next_cell))
            # Update utilities
            utility = self.utility(cell)
            ## Log.i('U[cell]={}'.format(utility))
            next_utility = self.utility(next_cell)
            ## Log.i('U[next_cell]={}'.format(next_utility))
            rhs = self.reward + self.discount_factor*next_utility - utility
            utility += self.learning_factor(iter_cnt) * rhs
            ## Log.i('new U[cell]={}'.format(utility))
            self.utility(cell, utility)
            cell = next_cell
    relative_directions = ['n', 'e', 's', 'w', 'n']
    def _movement_relative(self, cell):
        policy = cell.policy
        actual = cell.actual_move
        dirs = self.relative_directions
        diff = dirs.index(policy) - dirs.index(actual)
        if diff == 0:
            return 'straight'
        if diff == -1 or diff == 3:
            return 'right'
        if diff == 1 or diff == -3:
            return 'left'
        raise ValueError('between {} and {} a diff of {} isnt valid'.format(
            policy, actual, diff))
    def _left(self, cell):
        dirs = self.relative_directions
        return dirs[ dirs.index(cell.policy, 1) - 1 ]
    def _right(self, cell):
        dirs = self.relative_directions
        return dirs[ dirs.index(cell.policy) + 1 ]
    def _straight(self, cell):
        return cell.policy
    def _get_estimated_prob(self, cell, direction):
        iter_cnt = self.iteration_map[cell.row][cell.col]
        if iter_cnt == 0: return 0
        cell.actual_move = direction
        rel_dir = self._movement_relative(cell)
        return float(cell.actual_move_to_count_map[rel_dir]) / iter_cnt
    def utility(self, cell, value=None):
        if value is None:  # 'if not value' would also skip legitimate 0.0 updates
            return self.utility_map[cell.row][cell.col]
        self.utility_map[cell.row][cell.col] = value
    def update_policy_greedily(self, reset_utils=False, reset_counts=False):
        # Stupid, but effective: copy the learned utilities into cell.value so
        # the one-step lookahead below uses the current estimates
        for cell in self.world:
            cell.value = self.utility(cell)
        for cell in self.world:
            if cell.is_terminal or cell.blocks: continue
            # Policies:
            left = self._left(cell)
            right = self._right(cell)
            straight = cell.policy
            pleft = self._get_estimated_prob(cell, left)
            pright = self._get_estimated_prob(cell, right)
            pstraight = self._get_estimated_prob(cell, straight)
            policy = self._get_updated_policy(cell, pleft, pright, pstraight)
            ## self.policy_changed = True # HACK to print
            if policy != cell.policy:
                self.policy_changed = True
                ## print 'changed policy (now {}) for cell: {}'.format(policy, cell)
            cell.policy = policy
        if reset_utils:
            self._init_utilmap()
        if reset_counts:
            self._init_itercnts()
    def _get_updated_policy(self, state, pleft, pright, pstraight):
        if state.is_terminal or state.blocks: return
        max_pv = 0  # probability * value, as used in V(s) calculation
        # Moves in NSEW order
        moves_pvs = [
            pstraight*self.world.north(state).value + pleft*self.world.west(state).value + pright*self.world.east(state).value,
            pstraight*self.world.south(state).value + pright*self.world.west(state).value + pleft*self.world.east(state).value,
            pstraight*self.world.east(state).value + pleft*self.world.north(state).value + pright*self.world.south(state).value,
            pstraight*self.world.west(state).value + pright*self.world.north(state).value + pleft*self.world.south(state).value
        ]
        moves_directions = ['n', 's', 'e', 'w']
        max_idx = np.argmax(moves_pvs)
        max_pv = moves_pvs[max_idx]
        policy = moves_directions[max_idx]
        if max_pv < 1e-3:
            # If the max prob was ~zero, just leave the policy as-is
            return state.policy
        return policy

def print_grid_like_array(arr, min_cell_width=0, mapper=None):
    if not mapper: mapper = lambda x: x
    try:
        for row in arr:
            if min_cell_width > 0:
                print ''.join('{:{}}'.format(mapper(c), min_cell_width) for c in row)
            else:
                print ''.join('{}'.format(mapper(c)) for c in row)
            ## print row
    except AttributeError:
        pass

def print_util(arr):
    for row in arr:
        print ''.join('{:8.3f}'.format(c) for c in row)

def print_actual_move_to_count_map(cells):
    print_grid_like_array([[c.actual_move_to_count_map.items() for c in row] for row in cells])

import timeit

def temporal_difference_learning(active=False):
    global world, algo
    discount_factor = 0.9
    learning_factor = lambda Ns: 1.0 / (Ns + 1)  # alpha function
    world_input = np.array([[0,0,0,1], [0,np.nan,0,-1], [0,0,0,0]])
    world = GridWorld(world_input)
    world.start_cell(row=2, col=0)
    world.policy([
        ['e', 'e', 'e', '' ],
        ['s', '',  'n', '' ],
        ['e', 'e', 'n', 's']
    ])
    algo = TemporalDifferenceLearningAlgo(world, discount_factor=discount_factor, learning_factor=learning_factor)
    def run():
        print 'starting with policy:'
        print_grid_like_array(world.policy(), 3)
        def printout():
            print 'iteration {}'.format(iter_cnt).center(72, '-')
            print 'utilities:'
            print_util(algo.utility_map)
            ## print 'actual move counts:'
            ## print_actual_move_to_count_map(world.cells)
            ## print 'greedy policy values:'
            ## print_grid_like_array(world.cells, mapper=lambda c: '{:8.3f}'.format(c.greedy_policy_value))
        iter_cnt = 0
        prev_util = None
        end_cnt = 100000
        update_policy_after_cnt = 100
        while True:
            if iter_cnt <= 3 or iter_cnt >= end_cnt:
                printout()
                if iter_cnt > end_cnt: break
            if False and within_epsilon(prev_util, algo.utility_map, 0.0000001):  # convergence check, disabled by the leading False
                printout()
                break
            prev_util = [a[:] for a in algo.utility_map]
            algo.update(True)
            if active and iter_cnt % update_policy_after_cnt == 0 and iter_cnt > 0:
                prev_policy = world.policy()
                algo.update_policy_greedily()
                if algo.policy_changed:
                    algo.policy_changed = False
                    printout()
                    print 'previous policy:'
                    print_grid_like_array(prev_policy, 3)
                    print 'updated policy:'
                    print_grid_like_array(world.policy(), 3)
            iter_cnt += 1
        print 'Ns[cells]:'
        print_grid_like_array(algo.iteration_map, 8)
    t = timeit.Timer(run)
    seconds = t.timeit(1)
    print 'Took {}s to run.'.format(seconds)

def active_greedy():
    temporal_difference_learning(active=True)

def within_epsilon(prev, current, epsilon=0.0001):
    if not prev: return False  # so there's no need to init prev for first check
    for p_row, c_row in zip(prev, current):
        for p, c in zip(p_row, c_row):
            if abs(p-c) > epsilon:
                return False
    return True

import unittest

class TemporalDifferenceLearningAlgoTests(unittest.TestCase):
    def test_Given_lecture_grid_When_passive_applied_deterministically_Then_matches_lecture_10_10_results(self):
        world = GridWorld(np.array([[0,0,0,1], [0,np.nan,0,-1], [0,0,0,0]]))
        world.start_cell(row=2, col=0)
        world.policy([
            ['e', 'e', 'e', '' ],
            ['n', '',  'n', '' ],
            ['n', 'e', 'n', 'w']
        ])
        discount_factor = 1
        learning_factor = lambda Ns: 1.0 / (Ns + 1)  # alpha function
        target = TemporalDifferenceLearningAlgo(world, discount_factor=discount_factor, learning_factor=learning_factor)
        expected_utils = [a[:] for a in target.utility_map]
        expected_utils[0][2] = 1.0/2
        target.update()
        self.assertEqual(len(expected_utils), len(target.utility_map))
        self._assert_utils_almost_equal(expected_utils, target.utility_map)
        expected_utils[0][2] = 2.0/3
        expected_utils[0][1] = 1.0/6
        target.update()
        self._assert_utils_almost_equal(expected_utils, target.utility_map)
    def _assert_utils_almost_equal(self, expected, actual):
        # Compare element-wise (assertAlmostEqual can't take whole rows),
        # skipping the NaN entry that marks the blocked cell
        for e_row, a_row in zip(expected, actual):
            for e, a in zip(e_row, a_row):
                if np.isnan(e) and np.isnan(a):
                    continue
                self.assertAlmostEqual(e, a)

print 'Passive temporal diff learning:'
temporal_difference_learning()
print 'Active greedy:'
active_greedy()
##unittest.main()
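
For reference, a minimal usage sketch (not part of the original gist) showing how the classes above fit together: it rebuilds the deterministic setup from the unit test and runs two passive trials.

# Usage sketch (hypothetical driver; assumes the definitions above are in scope)
world = GridWorld(np.array([[0, 0, 0, 1], [0, np.nan, 0, -1], [0, 0, 0, 0]]))
world.start_cell(row=2, col=0)
world.policy([
    ['e', 'e', 'e', ''],
    ['n', '',  'n', ''],
    ['n', 'e', 'n', 'w'],
])
algo = TemporalDifferenceLearningAlgo(
    world, discount_factor=1, learning_factor=lambda Ns: 1.0 / (Ns + 1))
algo.update()                 # deterministic: follow the policy exactly
algo.update()
print_util(algo.utility_map)  # U at row 0, col 2 should be about 2/3 after two trials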
Passive temporal diff learning:
starting with policy:
e  e  e
s     n
e  e  n  s
------------------------------iteration 0-------------------------------
utilities:
0.000 0.000 0.000 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 1-------------------------------
utilities:
0.000 0.000 0.450 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 2-------------------------------
utilities:
0.000 0.000 0.450 1.000
0.000 nan -0.300 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 3-------------------------------
utilities:
0.000 0.000 0.600 1.000
0.000 nan -0.124 -1.000
0.000 0.000 -0.054 0.000
----------------------------iteration 100000----------------------------
utilities:
0.000 0.000 0.848 1.000
0.257 nan 0.572 -1.000
0.324 0.397 0.460 0.158
----------------------------iteration 100001----------------------------
utilities:
0.000 0.000 0.848 1.000
0.257 nan 0.572 -1.000
0.324 0.397 0.460 0.158
Ns[cells]:
0 0 109364 0
15797 0 123096 0
124979 140660 125044 124240
Took 19.8316876911s to run.
Active greedy:
starting with policy:
e  e  e
s     n
e  e  n  s
------------------------------iteration 0-------------------------------
utilities:
0.000 0.000 0.000 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 1-------------------------------
utilities:
0.000 0.000 0.450 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 2-------------------------------
utilities:
0.000 0.000 0.600 1.000
0.000 nan 0.135 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 3-------------------------------
utilities:
0.000 0.000 0.675 1.000
0.000 nan 0.236 -1.000
0.000 0.000 0.030 0.000
-----------------------------iteration 100------------------------------
utilities:
0.000 0.000 0.840 1.000
0.018 nan 0.484 -1.000
0.073 0.152 0.286 0.044
previous policy:
e  e  e
s     n
e  e  n  s
updated policy:
e  e  e
s     n
e  e  n  w
----------------------------iteration 100000----------------------------
utilities:
0.000 0.000 0.847 1.000
0.269 nan 0.569 -1.000
0.331 0.402 0.467 0.268
----------------------------iteration 100001----------------------------
utilities:
0.000 0.000 0.847 1.000
0.269 nan 0.569 -1.000
0.331 0.402 0.467 0.268
Ns[cells]:
0 0 108109 0
15567 0 121752 0
125065 140608 123332 13747
Took 18.1787866773s to run.