Q Learner for Nim, in Python
import random

num_sticks = 22
num_states = num_sticks + 6  # room for states past num_sticks (a move can overshoot by up to 2)
num_actions = 3
action_list = range(num_actions)
num_iterations = 10000
gamma = 0.1
alpha = 0.5

def best_action(q, state, default=None):
    '''Return the action with the highest q value for state.
    Ties are broken at random, unless default is specified,
    in which case default is returned on a tie.'''
    actions = q[state]
    best_q_value = max(actions)
    indices = [i for i, x in enumerate(actions) if x == best_q_value]
    if default is not None and len(indices) > 1:
        return default
    return random.choice(indices)
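
# For example, best_action([[0.0, 0.5, 0.0]], 0) returns 1 (the unique argmax),
# while best_action([[0.0, 0.0, 0.0]], 0, default='-') returns '-': all three
# actions are tied, which is what an unvisited state looks like.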


def print_policy(q):
    '''Print the learned policy.'''
    for i in range(num_sticks + 1):
        action = best_action(q, i, '-')
        if action != '-':
            action += 1  # actions are 0-indexed; a take is 1-3 sticks
        print("Sticks left:", num_sticks - i, "; Take:", action)


def main():
    # create a q-table, indexed as Q[state][action]
    Q = [[0.0 for _ in action_list] for _ in range(num_states)]
    for i in range(num_iterations):
        # reset game state (sticks taken so far)
        state = 0
        # play a game until all the sticks are gone
        while state < num_sticks:
            old_state = state
            # carry out my action, using an annealing approach: as i grows,
            # the right-hand draw wins more often, so we act greedily more often
            if random.randint(0, num_iterations - 1) > random.randint(0, i):
                my_action = random.randint(0, num_actions - 1)
            else:
                my_action = best_action(Q, state)
            state += my_action + 1
            if state >= num_sticks:
                reward = -1.0  # we lost (boo)
            else:
                # carry out opponent action
                opponent_action = best_action(Q, state)
                state += opponent_action + 1
                if state >= num_sticks:
                    reward = 1.0  # opponent lost (yay!)
                else:
                    reward = 0.0  # nobody lost (meh)
            # update the q-table with the standard q-learning rule:
            # Q'(s,a) = (1 - alpha) * Q(s,a) + alpha * (r + gamma * max_a' Q(s',a'))
            Q[old_state][my_action] = ((1.0 - alpha) * Q[old_state][my_action]
                                       + alpha * (reward + gamma * Q[state][best_action(Q, state)]))
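            # worked example: the first winning update of a state starts from
            # Q[s][a] = 0.0 with r = 1.0 and max_a' Q(s',a') = 0.0, giving
            # 0.5 * 0.0 + 0.5 * (1.0 + 0.1 * 0.0) = 0.5; repeating the win
            # moves it to 0.75, then 0.875, converging toward 1.0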
    print_policy(Q)


if __name__ == '__main__':
    main()
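
For comparison, this game is misère Nim (whoever takes the last stick loses), which has a known closed-form strategy: take (sticks_left - 1) mod 4, and when that is zero you are in a lost position against perfect play. Here is a minimal sketch of that analytic policy, handy as a sanity check against the table print_policy prints (optimal_take is a hypothetical helper, not part of the gist):

def optimal_take(sticks_left):
    '''Analytic optimal move for misere Nim with takes of 1-3:
    leave the opponent one stick more than a multiple of 4.'''
    take = (sticks_left - 1) % 4
    # zero means every move loses against perfect play; take 1 and hope
    return take if take > 0 else 1

for sticks in (22, 21, 5, 2):
    print("Sticks left:", sticks, "; Take:", optimal_take(sticks))

With 22 sticks the first player can win: taking 1 leaves the opponent 21, and answering each of their takes so the round sums to 4 keeps them on piles of the form 4k + 1 until they must take the last stick.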