Skip to content

Instantly share code, notes, and snippets.

@kengz
Created June 22, 2018 05:07
Show Gist options
  • Save kengz/adb964d23ae3e2cad3c22e9d6e2ba37e to your computer and use it in GitHub Desktop.
Save kengz/adb964d23ae3e2cad3c22e9d6e2ba37e to your computer and use it in GitHub Desktop.
def post_body_init(self):
    '''
    Finish the setup steps that can only run once the agent's bodies exist.

    Binds the first body of the agent as `self.body`, attaches an extra replay
    memory to that body, then initializes algorithm params and nets.
    '''
    # single-body algo: only the first body is used
    self.body = self.agent.nanflat_body_a[0]
    # the SIL replay memory class is looked up by the name given in the memory spec
    replay_cls = getattr(memory, self.memory_spec['sil_replay_name'])
    self.body.replay_memory = replay_cls(self.memory_spec, self, self.body)
    # params and nets are initialized only after the body and its replay memory are in place
    self.init_algorithm_params()
    self.init_nets()
    logger.info(util.self_desc(self))
# ...
def sample(self):
    '''
    Sample the on-policy batch and mirror every sampled experience into replay.

    Returns the concatenated batch of all bodies converted by
    `util.to_torch_batch`.
    '''
    body_batches = []
    for body in self.agent.nanflat_body_a:
        body_batches.append(body.memory.sample())
    batch = util.concat_batches(body_batches)

    # copy each experience, one index at a time, into the extra replay memory
    replay = self.body.replay_memory
    keys = replay.data_keys
    num_experiences = len(batch['dones'])
    for i in range(num_experiences):
        experience = [batch[key][i] for key in keys]
        replay.add_experience(*experience)

    # conversion happens after the copy, so replay receives the unconverted values
    return util.to_torch_batch(batch, self.net.gpu)
def replay_sample(self):
    '''
    Sample a batch from the replay memory of every body.

    Returns the concatenated batch converted by `util.to_torch_batch`.
    Asserts that the sampled states contain no NaN.
    '''
    body_batches = []
    for body in self.agent.nanflat_body_a:
        body_batches.append(body.replay_memory.sample())
    merged = util.concat_batches(body_batches)
    torch_batch = util.to_torch_batch(merged, self.net.gpu)
    # sanity check: fail right away if any sampled state is NaN
    has_nan_state = torch.isnan(torch_batch['states']).any()
    assert not has_nan_state
    return torch_batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment