Skip to content

Instantly share code, notes, and snippets.

View pythonlessons's full-sized avatar
🏠
Working from home

Rokas Liuberskis pythonlessons

🏠
Working from home
View GitHub Profile
@pythonlessons
pythonlessons / act_function.py
Created January 14, 2020 20:39
04_CartPole-reinforcement-learning_e_greedy_D3QN
def act(self, state, decay_step):
    """Compute the exploration probability for epsilon-greedy action selection.

    NOTE(review): this gist preview is truncated — the actual action
    selection (explore vs. exploit on `state`) and the return statement
    are not visible in this excerpt.

    Args:
        state: current environment observation (unused in the visible part).
        decay_step: global step counter driving the exponential decay.
    """
    # EPSILON GREEDY STRATEGY
    if self.epsilon_greedy:
        # Here we'll use an improved version of our epsilon greedy strategy for Q-learning:
        # exponential decay from self.epsilon down toward self.epsilon_min as decay_step grows.
        explore_probability = self.epsilon_min + (self.epsilon - self.epsilon_min) * np.exp(-self.epsilon_decay * decay_step)
    # OLD EPSILON STRATEGY
    else:
        # Multiplicative decay applied once per call, floored at epsilon_min.
        # Note this branch mutates self.epsilon (stateful), unlike the branch above.
        if self.epsilon > self.epsilon_min:
            self.epsilon *= (1-self.epsilon_decay)
        explore_probability = self.epsilon
@pythonlessons
pythonlessons / SumTree___init__.py
Last active January 15, 2020 06:44
05_CartPole-reinforcement-learning_PER_D3QN
class SumTree(object):
    """Binary sum-tree used by Prioritized Experience Replay.

    Leaves hold experience priorities; each internal node holds the sum of
    its children, so the root is the total priority.
    NOTE(review): this preview is truncated — the tree/data array
    initialization is cut off below.
    """
    # Index of the next leaf slot that will receive new experience data.
    data_pointer = 0

    # Here we initialize the tree with all nodes = 0, and initialize the data with all values = 0
    def __init__(self, capacity):
        # Number of leaf nodes (final nodes) that contains experiences
        self.capacity = capacity
        # Generate the tree with all nodes values = 0
        # To understand this calculation (2 * capacity - 1) look at the schema below
@pythonlessons
pythonlessons / SumTree_add.py
Last active January 15, 2020 06:44
05_CartPole-reinforcement-learning_PER_D3QN
def add(self, priority, data):
    # Store `data` at the next free leaf, giving that leaf `priority`.
    # NOTE(review): this preview is truncated — the triple-quoted block
    # below is unterminated in this excerpt, and the data write /
    # priority update / pointer wrap-around are not visible.
    # Look at what index we want to put the experience:
    # leaves occupy the last `capacity` slots of the flat tree array.
    tree_index = self.data_pointer + self.capacity - 1
    """ tree:
0
/ \
0 0
/ \ / \
tree_index 0 0 0 We fill the leaves from left to right
@pythonlessons
pythonlessons / SumTree_update.py
Last active January 15, 2020 06:44
05_CartPole-reinforcement-learning_PER_D3QN
def update(self, tree_index, priority):
    """Overwrite the priority at `tree_index` and propagate the delta to the root.

    Walks parent-by-parent instead of recursing, adding the signed change
    to every ancestor so internal sums stay consistent.
    """
    delta = priority - self.tree[tree_index]
    self.tree[tree_index] = priority

    # Climb toward the root (node 0), applying the delta at each ancestor.
    node = tree_index
    while node:
        node = (node - 1) // 2
        self.tree[node] += delta
@pythonlessons
pythonlessons / SumTree_get_leaf.py
Created January 15, 2020 06:43
05_CartPole-reinforcement-learning_PER_D3QN
def get_leaf(self, v):
    """Descend from the root toward the leaf whose cumulative priority range contains `v`.

    NOTE(review): this preview is truncated — the descent logic (choosing
    left vs. right child by comparing `v` with the left subtree sum) and
    the return statement are not visible in this excerpt.
    """
    parent_index = 0
    while True:
        # Children of node i in the flat array layout.
        left_child_index = 2 * parent_index + 1
        right_child_index = left_child_index + 1
        # If we reach bottom, end the search
        if left_child_index >= len(self.tree):
            leaf_index = parent_index
@pythonlessons
pythonlessons / SumTree_Memory.py
Created January 15, 2020 07:56
05_CartPole-reinforcement-learning_PER_D3QN
class Memory(object): # stored as ( state, action, reward, next_state ) in SumTree
    # Prioritized-replay hyperparameters (Schaul et al., "Prioritized
    # Experience Replay"). NOTE(review): this preview is truncated —
    # the __init__ body is cut off below.
    PER_e = 0.01 # Hyperparameter that we use to avoid some experiences to have 0 probability of being taken
    PER_a = 0.6 # Hyperparameter that we use to make a tradeoff between taking only exp with high priority and sampling randomly
    PER_b = 0.4 # importance-sampling, from initial value increasing to 1
    PER_b_increment_per_sampling = 0.001 # how much PER_b grows on each sample() call
    absolute_error_upper = 1. # clipped abs error

    def __init__(self, capacity):
@pythonlessons
pythonlessons / SumTree_store.py
Created January 15, 2020 08:17
05_CartPole-reinforcement-learning_PER_D3QN
def store(self, experience):
    """Insert a new experience into the sum-tree with the current maximum leaf priority.

    New transitions get the highest existing priority so every experience
    is sampled at least once before its priority is refined by training.
    """
    # The leaves live in the last `capacity` slots of the flat tree array.
    leaf_priorities = self.tree.tree[-self.tree.capacity:]
    max_priority = np.max(leaf_priorities)

    # A priority of 0 would make the experience unselectable forever,
    # so fall back to the clipping ceiling instead.
    priority = self.absolute_error_upper if max_priority == 0 else max_priority

    self.tree.add(priority, experience) # set the max priority for new priority
@pythonlessons
pythonlessons / SumTree_sample.py
Last active January 15, 2020 08:19
05_CartPole-reinforcement-learning_PER_D3QN
def sample(self, n):
    """Draw `n` experiences from the sum-tree, stratified over the priority range.

    NOTE(review): this preview is truncated — the per-segment sampling
    loop, importance-sampling weights, and return statement are not
    visible in this excerpt.
    """
    # Create a minibatch array that will contains the minibatch
    minibatch = []

    # Tree indices of the sampled leaves, needed later for batch_update().
    b_idx = np.empty((n,), dtype=np.int32)

    # Calculate the priority segment
    # Here, as explained in the paper, we divide the Range[0, ptotal] into n ranges
    priority_segment = self.tree.total_priority / n # priority segment
@pythonlessons
pythonlessons / SumTree_batch_update.py
Created January 15, 2020 08:19
05_CartPole-reinforcement-learning_PER_D3QN
def batch_update(self, tree_idx, abs_errors):
    """Refresh the priorities of the sampled leaves after a learning step.

    Args:
        tree_idx: iterable of leaf indices returned by sample().
        abs_errors: per-sample absolute TD errors (numpy array); not modified.
    """
    # Fix: the original did `abs_errors += self.PER_e`, which mutates the
    # caller's numpy array in place. Use a non-mutating shift instead.
    # NOTE(review): despite the original "convert to abs" comment, no
    # np.abs() was ever applied — callers are assumed to pass absolute
    # errors already; confirm at the call site.
    shifted = abs_errors + self.PER_e # avoid zero priority
    clipped_errors = np.minimum(shifted, self.absolute_error_upper) # clipped abs error
    ps = np.power(clipped_errors, self.PER_a) # priority = error^a
    for ti, p in zip(tree_idx, ps):
        self.tree.update(ti, p)
@pythonlessons
pythonlessons / replay.py
Created January 15, 2020 08:45
05_CartPole-reinforcement-learning_PER_D3QN
def replay(self):
    """Sample a training minibatch, from PER memory or the plain deque.

    NOTE(review): this preview is truncated — the placeholder string below
    stands in for the unchanged training logic that follows in the full gist.
    """
    if self.USE_PER:
        # Sample minibatch from the PER memory
        tree_idx, minibatch = self.MEMORY.sample(self.batch_size)
    else:
        # Randomly sample minibatch from the deque memory
        minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))
    '''
    everything stay the same here as before
    '''