Skip to content

Instantly share code, notes, and snippets.

@jackgoffinet
Last active February 28, 2017 19:30
Show Gist options
  • Save jackgoffinet/6d8b7f4876960836a6efa7534f9609e3 to your computer and use it in GitHub Desktop.
Save jackgoffinet/6d8b7f4876960836a6efa7534f9609e3 to your computer and use it in GitHub Desktop.
generic UCT implementation
"""
Generic UCT implementation.

Provides a search tree (`UctNode`) and a driver (`Uct`) that repeatedly
descends the tree, choosing among fully expanded children by UCB1 score,
and backs the resulting reward up the visited path.

Actions are represented by integers in [0,<num_actions>).
An action sequence must terminate within <seq_length> actions.
"""
__author__ = "Jack Goffinet"
__date__ = "December 2016 - February 2017"
from math import sqrt, log
class UctNode:
    """One node of the UCT search tree.

    Attributes
    ----------
    action : int
        The action stored on this node (the caller decides its meaning).
    val : float
        Value estimate for the node; starts at 0.0.
    n : int
        Visit counter; starts at 0.
    children : list or None
        Child nodes; None until the node is first expanded.
    """

    def __init__(self, action):
        # A fresh node has no statistics and no children yet.
        self.children = None
        self.val, self.n = 0.0, 0
        self.action = action
class Uct:
    """UCT search over sequences of integer actions.

    Each call to `play_root` descends the tree from the root, expands at
    most one new node, obtains a reward for the resulting action sequence
    and backs that reward up along the visited path.  The best reward seen
    so far and its sequence are kept in `best_reward` / `best_seq`.
    """

    def __init__(self, seq_length, num_actions, c, get_reward, gen, kwargs):
        """Initialize UCT.

        Parameters
        ----------
        seq_length : int
            An upper bound on the length of the action sequence.
        num_actions : int
            The number of possible actions.
        c : float
            UCT's exploration/exploitation parameter.
        get_reward : function
            Called as ``get_reward(seq=<actions>, best_seq=<best so far>,
            **kwargs)``.  Must return a pair ``(reward, seq)``: the reward
            (expected to lie in [0,1]) and the action sequence that earned
            it, which must support slicing.
        gen : Random
            Random generator; must provide ``randint`` and ``choice``.
        kwargs : dictionary
            Additional arguments for <get_reward>.  The key ``"best_seq"``
            is overwritten in this dictionary before every reward call.
            ``None`` is treated as an empty dictionary.
        """
        self.root = UctNode(-1)
        self.seq_length = seq_length
        self.num_actions = num_actions
        self.c = c
        # Parent node whose children `ucb1` is currently scoring.
        self.node = None
        # Below any reward in [0,1], so the first evaluation always wins.
        self.best_reward = -1.0
        self.best_seq = [gen.randint(0, num_actions - 1)
                         for _ in range(seq_length)]
        self.get_reward = get_reward
        self.gen = gen
        self.kwargs = {} if kwargs is None else kwargs

    def play_root(self):
        """Play the root node once.

        Descends the tree, collects a reward for the selected action
        sequence and updates the visit count and running mean value of
        every node on the path (the root included).
        """
        path = [self.root]
        seq = []
        # Traverse down the tree.  Stop as soon as a new node was created
        # or the sequence holds `seq_length` actions; strictly `<` so the
        # sequence can never exceed the bound.
        current = self.root
        keep_descending = True
        while keep_descending and len(seq) < self.seq_length:
            current, keep_descending = self.sample_children(current)
            seq.append(current.action)
            path.append(current)
        # Collect a reward.
        if len(seq) == self.seq_length and current.n > 0:
            # Full-length sequence that was evaluated before: reuse the
            # value stored on its terminal node.
            reward = current.val
        else:
            self.kwargs["best_seq"] = self.best_seq
            reward, seq = self.get_reward(seq=seq, **self.kwargs)
            # Only a fresh evaluation can improve on the best so far.
            if reward > self.best_reward:
                self.best_reward = reward
                self.best_seq = seq[:]
        # Propagate the reward back up the tree (incremental mean).
        for visited in path:
            visited.n += 1
            visited.val += (reward - visited.val) / visited.n

    def sample_children(self, node):
        """Return a biased sample of the given node's children.

        Returns
        -------
        (UctNode, bool)
            The selected child and a flag that is False when the child was
            newly created by this call and True when an existing child was
            selected by UCB1 score.
        """
        # If the node has no children, choose one at random.
        if node.children is None:
            node.children = [None] * self.num_actions
            action = self.gen.randint(0, self.num_actions - 1)
            node.children[action] = UctNode(action)
            return node.children[action], False
        # Initialize a random child, if possible.
        unexplored = [a for a in range(self.num_actions)
                      if node.children[a] is None]
        if unexplored:
            action = self.gen.choice(unexplored)
            node.children[action] = UctNode(action)
            return node.children[action], False
        # Otherwise, choose a child according to its UCB1 score.
        self.node = node  # `ucb1` reads the parent from self.node
        scores = [self.ucb1(a) for a in range(self.num_actions)]
        return node.children[self.argmax(scores)], True

    def ucb1(self, child_num):
        """Return the UCB1 score of self.node.children[<child_num>].

        The score is the child's value plus
        ``c * sqrt(log(parent visits) / child visits)``.
        """
        parent = self.node
        child = parent.children[child_num]
        return child.val + self.c * sqrt(log(parent.n) / child.n)

    def argmax(self, iterable):
        """Return the index of a maximal element, ties broken at random.

        Accepts any iterable of mutually comparable scores.  The tie is
        broken with ``self.gen.choice``; an empty input therefore fails
        inside that call.
        """
        scores = list(iterable)
        # Use the true maximum rather than a fixed sentinel so that scores
        # of any magnitude (including values below -1) are handled.
        best = max(scores, default=None)
        candidates = [i for i, score in enumerate(scores) if score == best]
        return self.gen.choice(candidates)
if __name__ == "__main__":
    # Nothing to run as a script; the module is meant to be imported.
    pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment