Last active
October 8, 2016 18:24
-
-
Save chetandhembre/7fc7d6d24f22f98a9db1ab4d2e8128dc to your computer and use it in GitHub Desktop.
Easy21
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python2
#plot link: https://dl.dropboxusercontent.com/u/47591917/easy21_mc.png | |
import numpy as np | |
import matplotlib | |
from matplotlib import pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D | |
# Card colours.
RED, BLACK = 0, 1
# Player actions. The values double as indices into the 2-element
# probability vector used for epsilon-greedy action selection.
STICK, HIT = 0, 1
# Episode outcomes (rewards); BUST is also used for an ordinary loss.
WIN, BUST, DRAW = 1, -1, 0
# Action-selection modes (not referenced in the code shown here).
GREDDY, EXPLORE = 0, 1
def plot_surface(X, Y, Z, title, filename='easy21_mc.png'):
    """Draw Z over the (X, Y) grid as a 3-D surface, save it, then show it.

    X -- grid plotted on the axis labelled 'Player Sum'
    Y -- grid plotted on the axis labelled 'Dealer Showing'
    Z -- values to plot; the colour map is fixed to the range [-1, 1]
    title -- plot title
    filename -- path the figure is saved to before it is displayed
                (defaults to the previously hard-coded 'easy21_mc.png')
    """
    fig = plt.figure(figsize=(20, 10))
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                           cmap=matplotlib.cm.coolwarm, vmin=-1.0, vmax=1.0)
    ax.set_xlabel('Player Sum')
    ax.set_ylabel('Dealer Showing')
    ax.set_zlabel('Value')
    ax.set_title(title)
    # Keep the current elevation; rotate the azimuth for a readable view.
    ax.view_init(ax.elev, -120)
    fig.colorbar(surf)
    plt.savefig(filename)
    plt.show()
class State(object):
    """A player-turn state: the player's sum and the dealer's sum.

    Easy21 has no aces, so unlike blackjack there is no usable-ace flag.
    States compare and hash by value so they can be used as dict keys.
    """

    def __init__(self, players_sum, dealer_card):
        self.player_sum = players_sum
        # The argument is the dealer's card; it is stored as `dealer_sum`,
        # the same name the environment uses for the dealer's running total.
        self.dealer_sum = dealer_card

    def _key(self):
        return (self.player_sum, self.dealer_sum)

    def __eq__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        return 'State(%r, %r)' % (self.player_sum, self.dealer_sum)
class StateAction(object):
    """A (state, action) pair; the key type of the action-value table."""

    def __init__(self, player_sum, dealer_sum, action):
        self.state = State(player_sum, dealer_sum)
        self.action = action

    def __hash__(self):
        return hash((self.state, self.action))

    def __eq__(self, other):
        if not isinstance(other, StateAction):
            return NotImplemented
        return self.state == other.state and self.action == other.action

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        return 'StateAction(%r, %r, %r)' % (
            self.state.player_sum, self.state.dealer_sum, self.action)
class Card(object):
    """A drawn card: its face value and its colour (RED or BLACK)."""

    def __init__(self, number, color):
        self.number, self.color = number, color
class ValueMap(object):
    """Tabular action values plus the visit counts used for epsilon-greedy control.

    Attributes:
        states -- dict StateAction -> current estimate Q(s, a)
        visited_count -- dict StateAction -> N(s, a), number of updates applied
        state_visited_count -- dict State -> N(s), drives the epsilon schedule
        N0 -- exploration constant; epsilon = N0 / (N0 + N(s))
    """

    def __init__(self, N0=100):
        self.states = {}
        self.visited_count = {}
        self.state_visited_count = {}
        self.N0 = N0
        self.initialize()

    def initialize(self):
        """Zero Q(s, a) and N(s, a) for player sums 1..21, dealer cards 1..10, both actions."""
        for player_sum in range(1, 21 + 1):
            for card in range(1, 10 + 1):
                for action in [HIT, STICK]:
                    key = StateAction(player_sum, card, action)
                    self.states[key] = 0
                    self.visited_count[key] = 0

    def state_visited(self, state):
        """Increment N(state)."""
        self.state_visited_count[state] = self.state_visited_count.get(state, 0) + 1

    def greedy_or_explore(self, state, action):
        """Epsilon-greedy choice around `action`.

        Returns `action` with probability 1 - epsilon + epsilon / 2 and the
        other action with probability epsilon / 2, where
        epsilon = N0 / (N0 + N(state)).
        """
        epsilon = self.N0 / float(self.N0 + self.state_visited_count.get(state, 0))
        probabilities = np.ones(2) * epsilon / 2
        # STICK (0) and HIT (1) double as indices into `probabilities`.
        probabilities[action] = probabilities[action] + (1 - epsilon)
        return np.random.choice(len(probabilities), p=probabilities)

    def select_action_state(self, state):
        """Return an epsilon-greedy action for `state` (used while learning)."""
        return self.greedy_or_explore(state, self.select_greedy_action(state))

    def select_greedy_action(self, state):
        """Return the action with the higher Q value for `state`; ties go to STICK."""
        hit_value = self.states[StateAction(state.player_sum, state.dealer_sum, HIT)]
        stick_value = self.states[StateAction(state.player_sum, state.dealer_sum, STICK)]
        return HIT if hit_value > stick_value else STICK
def select_card():
    """Draw a card: value uniform on 1..10; RED w.p. 1/3, BLACK w.p. 2/3."""
    # Without a `size` argument np.random.choice returns a single value; the
    # previous size=1 call returned a one-element array instead of a colour.
    color = np.random.choice([RED, BLACK], p=[1 / float(3), 2 / float(3)])
    # randint's upper bound is exclusive, so this is exactly 1..10.
    # int(uniform(1, 11)) could yield 11, because numpy documents that the
    # high limit may be returned due to floating-point rounding.
    number = int(np.random.randint(1, 11))
    return Card(number, color)
class Easy21Env(object):
    """Easy21 episode state: the running player and dealer sums.

    Both sides start with one card that counts positively. Later draws add
    the card's value when it is BLACK and subtract it when it is RED. A sum
    above 21 or below 1 is a bust. The dealer hits until its sum reaches
    DEALER_STICK_THRESHOLD or it busts.
    """

    # Dealer sticks on any sum at or above this value (Easy21 rules).
    DEALER_STICK_THRESHOLD = 17

    def __init__(self):
        # Opening cards always count positively (the first draw is black).
        self.player_sum = select_card().number
        self.dealer_sum = select_card().number

    @staticmethod
    def _is_bust(total):
        """True when `total` is outside the legal range 1..21."""
        return total > 21 or total < 1

    @staticmethod
    def _card_value(card):
        """Signed contribution of a drawn card: red subtracts, black adds."""
        return -card.number if card.color == RED else card.number

    def get_reward(self):
        """Return BUST (loss), WIN or DRAW for the current sums."""
        if self._is_bust(self.player_sum):
            return BUST
        if self._is_bust(self.dealer_sum):
            return WIN
        if self.player_sum > self.dealer_sum:
            return WIN
        if self.player_sum < self.dealer_sum:
            return BUST
        return DRAW

    def take_player_hit(self):
        """Draw one card for the player and apply it to the player's sum."""
        self.player_sum = self.player_sum + self._card_value(select_card())

    def dealer_move(self):
        """Play out the dealer: hit below the stick threshold unless bust."""
        while (not self._is_bust(self.dealer_sum)
               and self.dealer_sum < self.DEALER_STICK_THRESHOLD):
            self.dealer_sum = self.dealer_sum + self._card_value(select_card())

    def get_state(self):
        """Return the current (player sum, dealer sum) as a State."""
        return State(self.player_sum, self.dealer_sum)
def select_move(player_sum):
    """Fixed baseline policy: stick once the sum exceeds 20, otherwise hit."""
    if player_sum > 20:
        return STICK
    return HIT
class Game(object):
    """Runs Monte Carlo control episodes on Easy21Env and reports the result."""

    def __init__(self, no_episodes):
        self.no_episodes = no_episodes
        self.value_map = ValueMap()

    def _handle_value_map(self, reward, state_visited_order, states_visited):
        """Move Q(s, a) toward the episode reward for every visited pair.

        reward -- the final reward of the episode (the only, undiscounted return)
        state_visited_order -- iterable of StateAction keys, one per visit
        states_visited -- per-episode visit counts; not needed by the update,
                          kept for interface compatibility
        The step size is 1 / N(s, a), so each estimate is the running mean
        of the returns observed for that pair.
        """
        q_values = self.value_map.states
        counts = self.value_map.visited_count
        for action_state in state_visited_order:
            counts[action_state] = counts[action_state] + 1
            q_values[action_state] = q_values[action_state] + (
                (reward - q_values[action_state]) / float(counts[action_state]))

    def _greedy_value(self, player_sum, dealer_card):
        """Q value of the greedy action in state (player_sum, dealer_card)."""
        action = self.value_map.select_greedy_action(State(player_sum, dealer_card))
        return self.value_map.states[StateAction(player_sum, dealer_card, action)]

    def plot(self):
        """Print the greedy policy and plot the greedy action values.

        The printed grid has one row per player sum, 21 at the top down to 12
        at the bottom, and one column per dealer card 1..10; each digit is
        the greedy action constant (STICK or HIT).
        """
        player_sums = range(12, 21 + 1)
        dealer_cards = range(1, 10 + 1)
        rows = [''.join(str(self.value_map.select_greedy_action(State(player, card)))
                        for card in dealer_cards)
                for player in player_sums]
        # Highest player sum first, with a trailing newline, as before.
        print('\n'.join(reversed(rows)) + '\n')
        X, Y = np.meshgrid(np.array(dealer_cards), np.array(player_sums))
        # Z[i][j] belongs to player_sums[i] and dealer_cards[j], matching meshgrid.
        Z_noace = np.array([[self._greedy_value(player, card) for card in dealer_cards]
                            for player in player_sums])
        title = "easy21"
        plot_surface(Y, X, Z_noace, "{} (No Usable Ace)".format(title))

    def _play_episode(self):
        """Run one epsilon-greedy episode and apply the Monte Carlo update."""
        env = Easy21Env()
        is_busted = False
        states_visited = {}
        state_visited_order = []
        state = env.get_state()
        while True:
            # Count every visit to s, including the opening state, before
            # choosing an action, since epsilon depends on N(s). Previously
            # only states reached after a hit were counted.
            self.value_map.state_visited(state)
            action = self.value_map.select_action_state(state)
            action_state = StateAction(state.player_sum, state.dealer_sum, action)
            states_visited[action_state] = states_visited.get(action_state, 0) + 1
            state_visited_order.append(action_state)
            if action == STICK:
                break
            env.take_player_hit()
            if env.player_sum > 21 or env.player_sum < 1:
                is_busted = True
                break
            state = env.get_state()
        if is_busted:
            reward = BUST
        else:
            env.dealer_move()
            reward = env.get_reward()
        self._handle_value_map(reward, reversed(state_visited_order), states_visited)

    def play(self):
        """Run `no_episodes` learning episodes."""
        for _ in range(self.no_episodes):
            self._play_episode()
if __name__ == '__main__':
    # Train for 500k episodes, then print the policy and plot the values.
    game = Game(500000)
    game.play()
    game.plot()

# Greedy policy that appears to have been printed by Game.plot() in an
# earlier run. Rows are player sums 21 (top) down to 12 (bottom); columns are
# dealer cards 1..10; 1 = HIT, 0 = STICK.
# 0000000000
# 0000000000
# 0000000000
# 0000000000
# 0000000000
# 1111101111
# 1111111111
# 1111111111
# 1111111111
# 1111111111
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment