import torch

# `args` is assumed to be defined elsewhere in the project and to hold the
# board dimensions args.M (rows) and args.N (columns).

class Node:
    def __init__(self, state, model):
        # saves state as a dictionary
        self.state = state
        # needs access to the neural network model
        self.model = model
        # W is the total reward and N is the number of playouts
        self.W = 0
        self.N = 0
        self.value = None
        self.policy = None
        # sets which actions are valid and which are invalid
        # in the variables self.valid_actions and
        # self.invalid_actions respectively
        self.set_action_validity()
        # for all valid actions, initialize new nodes (but don't
        # fill them yet with states, i.e., lazy initialization)
        self.initialize_edges()
        # None if it's not a terminal state, otherwise 'red' or 'green'
        # indicating the winner of that terminal state
        self.win = None

    def initialize_edges(self):
        if self.state is not None:
            # a dictionary with action tuples as keys
            # and nodes as values
            self.children = {}
            for row in range(args.M):
                for col in range(args.N):
                    if self.valid_actions[row][col]:
                        self.children[(row, col)] = Node(None, self.model)

    def set_action_validity(self):
        # what's an invalid action?
        # a player can click anywhere on the board
        # except those cells where orbs from the
        # opposite player reside
        if self.state is not None:
            if self.state['player_turn'] == 'red':
                self.invalid_actions = self.state['array_view'] < 0
            else:
                self.invalid_actions = self.state['array_view'] > 0
            self.valid_actions = ~self.invalid_actions

    def make_forward_pass(self):
        # computes the policy and value for the current node;
        # both are used by MCTS: the policy guides tree traversal
        # and the value replaces Monte Carlo rollouts
        with torch.no_grad():
            out = self.model(
                self.model.state_array_view_to_tensor(self.state)
            )
        self.policy = out['policy'][0].cpu().numpy()
        self.value = out['value'].cpu().item()
        # mask out invalid actions before renormalizing
        self.policy[self.invalid_actions] = 0
        # handle the rare case where the sum becomes zero, which can
        # happen when low-magnitude values are treated as zero
        if self.policy.sum() == 0:
            self.policy[self.valid_actions] = 1
        self.policy /= self.policy.sum()

    def get_policy(self):
        if self.policy is None:
            self.make_forward_pass()
        return self.policy

    def get_value(self):
        if self.value is None:
            self.make_forward_pass()
        return self.value
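
For context, here is a minimal sketch of how a Node's statistics (W, N) and network policy could feed a PUCT-style selection step during MCTS. The gist only defines the Node class itself, so select_action and c_puct are hypothetical names, and the assumption that the policy array is shaped (args.M, args.N) and indexed by (row, col) tuples is an illustration, not part of the original code.

import math

def select_action(node, c_puct=1.0):
    # hypothetical helper, not part of the gist: pick the child action that
    # maximizes the PUCT score Q + U over the node's lazily created children
    total_visits = sum(child.N for child in node.children.values())
    policy = node.get_policy()  # assumed (args.M, args.N) array, invalid actions zeroed
    best_action, best_score = None, -float('inf')
    for action, child in node.children.items():
        # Q is the mean reward W / N; unvisited children default to 0
        q = child.W / child.N if child.N > 0 else 0.0
        # U scales the network prior by how under-explored this child is
        u = c_puct * policy[action] * math.sqrt(total_visits) / (1 + child.N)
        if q + u > best_score:
            best_action, best_score = action, q + u
    return best_action

Because child nodes are created lazily with state=None, the caller would still need to expand the selected child (fill in its state) before descending further.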