Last active
February 28, 2017 19:30
-
-
Save jackgoffinet/6d8b7f4876960836a6efa7534f9609e3 to your computer and use it in GitHub Desktop.
generic UCT implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Generic UCT (Upper Confidence bounds applied to Trees) search.

Actions are encoded as integers in the half-open range [0, <num_actions>).
Every action sequence must end after at most <seq_length> actions.
"""
__author__ = "Jack Goffinet"
__date__ = "December 2016 - February 2017"

from math import log, sqrt
class UctNode:
    """A single node of the UCT search tree.

    Attributes
    ----------
    action : int
        The action leading from the parent to this node (-1 for the root).
    val : float
        Running mean of the rewards backed up through this node.
    n : int
        Number of times this node has been visited.
    children : list or None
        None until the node is first expanded; afterwards a list of length
        <num_actions> holding a UctNode per tried action and None for
        actions not yet tried.
    """

    def __init__(self, action):
        self.action = action
        self.val = 0.0
        self.n = 0
        self.children = None


class Uct:
    """UCT search over fixed-alphabet action sequences."""

    def __init__(self, seq_length, num_actions, c, get_reward, gen,
                 kwargs=None):
        """Initialize UCT.

        Parameters
        ----------
        seq_length : int
            An upper bound on the length of the action sequence.
        num_actions : int
            The number of possible actions.
        c : float
            UCT's exploration/exploitation parameter.
        get_reward : function
            Called as ``get_reward(seq=..., **kwargs)``, where ``kwargs``
            always contains "best_seq" (the best action sequence found so
            far) and ``seq`` is the action sequence chosen during tree
            descent. Must return a pair ``(reward, seq)``: the reward
            (expected to be a float in [0,1]) and the action sequence to
            record as the best one if that reward is a new maximum.
        gen : Random
            Random generator (must provide ``randint`` and ``choice``).
        kwargs : dictionary, optional
            Additional arguments for <get_reward>. Defaults to an empty
            dict. Its "best_seq" entry is overwritten, in this very dict,
            before every call to <get_reward>.
        """
        self.root = UctNode(-1)
        self.seq_length = seq_length
        self.num_actions = num_actions
        self.c = c
        # Parent node whose children ucb1() scores; set by sample_children.
        self.node = None
        # Below every reward in [0,1], so the first reward always replaces it.
        self.best_reward = -1.0
        # Random placeholder until get_reward reports something better.
        self.best_seq = [gen.randint(0, num_actions - 1)
                         for _ in range(seq_length)]
        self.get_reward = get_reward
        self.gen = gen
        self.kwargs = {} if kwargs is None else kwargs

    def play_root(self):
        """Play the root node once.

        Descends the tree (expanding at most one new node), obtains a
        reward, updates the best sequence seen so far, and backs the reward
        up along the visited path.
        """
        node_list = [self.root]
        seq = []
        # Traverse down the tree until a new node is created or the sequence
        # reaches its maximum length. The bound must be strict: with `<=`
        # the tree grew one level too deep and get_reward received
        # sequences of length seq_length + 1.
        temp = self.root
        flag = True
        while flag and len(seq) < self.seq_length:
            temp, flag = self.sample_children(temp)
            seq.append(temp.action)
            node_list.append(temp)
        # Collect a reward. A full-length path ending at an already visited
        # node reuses that node's mean reward instead of calling get_reward.
        if len(seq) == self.seq_length and temp.n > 0:
            reward = temp.val
        else:
            self.kwargs["best_seq"] = self.best_seq
            reward, seq = self.get_reward(seq=seq, **self.kwargs)
            if reward > self.best_reward:
                self.best_reward = reward
                self.best_seq = seq[:]
        # Propagate the reward back up the tree (incremental mean update).
        for visited in node_list:
            visited.n += 1
            visited.val += (reward - visited.val) / visited.n

    def sample_children(self, node):
        """Return a biased sample of the given node's children.

        Returns a pair ``(child, fully_expanded)``. While <node> still has
        untried actions, one of them is picked at random, a child node is
        created for it, and ``fully_expanded`` is False. Once every action
        has a child, the child with the highest UCB1 score is returned
        (ties broken at random) and ``fully_expanded`` is True.
        """
        # First visit: allocate the child slots and expand a random action.
        if node.children is None:
            node.children = [None] * self.num_actions
            action = self.gen.randint(0, self.num_actions - 1)
            node.children[action] = UctNode(action)
            return node.children[action], False
        # Initialize a random untried child, if possible.
        untried = [i for i, child in enumerate(node.children) if child is None]
        if untried:
            action = self.gen.choice(untried)
            node.children[action] = UctNode(action)
            return node.children[action], False
        # Otherwise, choose a child according to its UCB1 score.
        self.node = node
        scores = [self.ucb1(i) for i in range(self.num_actions)]
        return node.children[self.argmax(scores)], True

    def ucb1(self, child_num):
        """Return the UCB1 score of self.node.children[<child_num>].

        score = child mean reward + c * sqrt(ln(parent visits) / child visits)
        """
        return (self.node.children[child_num].val +
                (self.c * sqrt(log(self.node.n) /
                               self.node.children[child_num].n)))

    def argmax(self, iterable):
        """Return the index of a maximal element of <iterable>.

        Ties are broken uniformly at random using self.gen. Any real-valued
        scores are accepted (the search starts from -inf). An empty
        iterable is passed as an empty list to ``gen.choice``, which raises
        IndexError for random.Random.
        """
        max_element = float("-inf")
        best_indices = []
        for i, value in enumerate(iterable):
            if value > max_element:
                max_element = value
                best_indices = [i]
            elif value == max_element:
                best_indices.append(i)
        return self.gen.choice(best_indices)
if __name__ == "__main__":
    # The module only provides the UCT classes; running it as a script
    # does nothing.
    ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment