Created
February 18, 2018 12:52
-
-
Save breeko/c60e543274e139390d0a95818ba579e0 to your computer and use it in GitHub Desktop.
Experience replay class for reinforcement learning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class ExperienceReplay:
    """Fixed-capacity FIFO store of experiences with uniform random sampling.

    Experiences are kept in ``self.buffer`` (a plain list, oldest first).
    Each experience is expected to be a sequence shaped like
    ``[state, action, reward, next_state, done]``.
    """

    def __init__(self, buffer_size=50000):
        """Create an empty buffer.

        Args:
            buffer_size: Maximum number of experiences retained. A value of
                zero (or less) means nothing is retained.
        """
        # Buffer will contain [state, action, reward, next_state, done]
        self.buffer = []
        self.buffer_size = buffer_size

    def add(self, experience):
        """Append a list of experiences, discarding the oldest on overflow.

        Args:
            experience: Iterable of experiences; every element becomes one
                buffer entry (the iterable itself is not stored).
        """
        self.buffer.extend(experience)

        # Trim by an explicit overflow count. Slicing with
        # ``[-self.buffer_size:]`` is wrong for a capacity of 0 because
        # ``[-0:]`` selects the whole list, and it also copies the buffer on
        # every call even when nothing has to be dropped.
        overflow = len(self.buffer) - self.buffer_size
        if overflow > 0:
            self.buffer = self.buffer[overflow:]

    def sample(self, size):
        """Return ``size`` experiences drawn uniformly, with replacement.

        Args:
            size: Number of experiences to draw.

        Returns:
            A NumPy array reshaped to ``(size, -1)``; one row per drawn
            experience. When the experiences mix scalars with sequences
            (e.g. array-valued states) the array has ``dtype=object`` and
            one column per experience field.

        Raises:
            ValueError: If the buffer is empty, or if the drawn experiences
                cannot be arranged into ``size`` rows.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty experience buffer")

        sample_idxs = np.random.randint(len(self.buffer), size=size)
        picked = [self.buffer[idx] for idx in sample_idxs]

        try:
            batch = np.asarray(picked)
        except ValueError:
            # NumPy >= 1.24 refuses to build an array from ragged nested
            # sequences. Reproduce what older releases did implicitly: an
            # object array with one cell per experience field.
            try:
                widths = {len(entry) for entry in picked}
            except TypeError:
                widths = set()
            if len(widths) != 1:
                # Rows of differing length cannot form a 2-D batch; surface
                # the original conversion error unchanged.
                raise
            batch = np.empty((len(picked), widths.pop()), dtype=object)
            for row, entry in enumerate(picked):
                for col, field in enumerate(entry):
                    batch[row, col] = field

        return np.reshape(batch, (size, -1))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment