This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def act(self, state, decay_step): | |
| # EPSILON GREEDY STRATEGY | |
| if self.epsilon_greedy: | |
| # Here we'll use an improved version of our epsilon greedy strategy for Q-learning | |
| explore_probability = self.epsilon_min + (self.epsilon - self.epsilon_min) * np.exp(-self.epsilon_decay * decay_step) | |
| # OLD EPSILON STRATEGY | |
| else: | |
| if self.epsilon > self.epsilon_min: | |
| self.epsilon *= (1-self.epsilon_decay) | |
| explore_probability = self.epsilon |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class SumTree(object): | |
| data_pointer = 0 | |
| # Here we initialize the tree with all nodes = 0, and initialize the data with all values = 0 | |
| def __init__(self, capacity): | |
| # Number of leaf nodes (final nodes) that contains experiences | |
| self.capacity = capacity | |
| # Generate the tree with all nodes values = 0 | |
| # To understand this calculation (2 * capacity - 1) look at the schema below |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def add(self, priority, data): | |
| # Look at what index we want to put the experience | |
| tree_index = self.data_pointer + self.capacity - 1 | |
| """ tree: | |
| 0 | |
| / \ | |
| 0 0 | |
| / \ / \ | |
| tree_index 0 0 0 We fill the leaves from left to right |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update(self, tree_index, priority):
    """Overwrite the priority stored at ``tree_index`` and refresh every ancestor.

    The tree is stored as a flat array in which node ``i`` has parent
    ``(i - 1) // 2``. Each parent holds the sum of its children, so the
    difference between the new and old value is added to every node on the
    path from ``tree_index`` up to the root (index 0).

    Args:
        tree_index: flat-array index of the node to overwrite.
        priority: new value to store at that node.
    """
    # Capture the delta before overwriting, then push it up the parent chain
    # iteratively (no recursion).
    delta = priority - self.tree[tree_index]
    self.tree[tree_index] = priority

    node = tree_index
    while node:
        node = (node - 1) // 2
        self.tree[node] += delta
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def get_leaf(self, v): | |
| parent_index = 0 | |
| while True: | |
| left_child_index = 2 * parent_index + 1 | |
| right_child_index = left_child_index + 1 | |
| # If we reach bottom, end the search | |
| if left_child_index >= len(self.tree): | |
| leaf_index = parent_index |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class Memory(object): # stored as ( state, action, reward, next_state ) in SumTree | |
| PER_e = 0.01 # Hyperparameter that we use to avoid some experiences to have 0 probability of being taken | |
| PER_a = 0.6 # Hyperparameter that we use to make a tradeoff between taking only exp with high priority and sampling randomly | |
| PER_b = 0.4 # importance-sampling, from initial value increasing to 1 | |
| PER_b_increment_per_sampling = 0.001 | |
| absolute_error_upper = 1. # clipped abs error | |
| def __init__(self, capacity): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def store(self, experience):
    """Insert ``experience`` into the sum tree at the current maximum leaf priority.

    The priority is the largest value among the last ``self.tree.capacity``
    entries of ``self.tree.tree`` (the leaf slots). If every leaf is still
    0, a 0 priority would mean the experience could never be selected.
    In that case ``self.absolute_error_upper`` is used instead; this is the
    clipping ceiling, not a minimum.

    Args:
        experience: the transition to store; passed unchanged to ``self.tree.add``.
    """
    leaves = self.tree.tree[-self.tree.capacity:]
    top = np.max(leaves)
    priority = top if top != 0 else self.absolute_error_upper
    self.tree.add(priority, experience)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def sample(self, n): | |
| # Create a minibatch array that will contains the minibatch | |
| minibatch = [] | |
| b_idx = np.empty((n,), dtype=np.int32) | |
| # Calculate the priority segment | |
| # Here, as explained in the paper, we divide the Range[0, ptotal] into n ranges | |
| priority_segment = self.tree.total_priority / n # priority segment |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def batch_update(self, tree_idx, abs_errors):
    """Recompute and store the priorities of a sampled batch from its TD errors.

    Each priority is ``min(|error| + PER_e, absolute_error_upper) ** PER_a``.
    ``PER_e`` keeps every priority strictly positive; the upper bound clips
    outliers.

    Args:
        tree_idx: sum-tree indices of the sampled leaves, one per error.
        abs_errors: TD errors for those leaves (array-like). The input is not
            modified.
    """
    # Build a new array instead of `abs_errors += ...`. The in-place form
    # mutated the caller's array, raised TypeError for plain lists, and raised
    # a casting error for integer arrays. np.abs makes the "convert to abs"
    # intent real: a negative error would otherwise make np.power(neg, PER_a)
    # return nan and poison the tree sums.
    errors = np.abs(abs_errors) + self.PER_e
    clipped_errors = np.minimum(errors, self.absolute_error_upper)
    ps = np.power(clipped_errors, self.PER_a)

    for ti, p in zip(tree_idx, ps):
        self.tree.update(ti, p)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def replay(self): | |
| if self.USE_PER: | |
| # Sample minibatch from the PER memory | |
| tree_idx, minibatch = self.MEMORY.sample(self.batch_size) | |
| else: | |
| # Randomly sample minibatch from the deque memory | |
| minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size)) | |
| ''' | |
| everything stay the same here as before | |
| ''' |