Further adventures in reinforcement learning
import random

import numpy as np
from collections import deque


class QAgent:
    """Deep Q-learning agent backed by a Keras-style model and a replay memory."""

    def __init__(self, model, γ=0.98, batch_size=64):
        self.model = model
        self.memory = deque(maxlen=10000)   # replay buffer of (s, a, r, s', done) tuples
        self.γ = γ                          # discount factor for future rewards
        self.batch_size = batch_size
        self.metrics_log = []

    def get_Q(self, state):
        # Predict Q-values for a single state: reshape to a batch of one and back.
        xs = state.reshape(1, state.shape[0])
        ys = self.model.predict(xs)
        return ys.reshape(ys.shape[1])

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def learn(self, rounds=1):
        # Train on randomly sampled minibatches drawn from the replay memory.
        sample_size = min(len(self.memory), self.batch_size)
        for _ in range(rounds):
            xs, ys = self.experiences(sample_size)
            self.model.train_on_batch(np.array(xs), np.array(ys))

    def test(self):
        # Evaluate the model on a fresh sample and record the metrics.
        sample_size = min(len(self.memory), self.batch_size)
        xs, ys = self.experiences(sample_size)
        metrics = self.model.evaluate(np.array(xs), np.array(ys), verbose=0)
        self.metrics_log.append(metrics)
        return metrics

    def experiences(self, sample_size):
        # Build training targets from a random sample of remembered transitions,
        # using the Q-learning update: Q(s, a) ← r + γ · max_a' Q(s', a').
        sample = random.sample(self.memory, sample_size)
        xs, ys = [], []
        for state, action, reward, next_state, done in sample:
            Qrow = self.get_Q(state)
            Qnext = self.get_Q(next_state)
            Qrow[action] = reward if done else reward + self.γ * np.max(Qnext)
            xs.append(state)
            ys.append(Qrow)
        return xs, ys
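
A minimal usage sketch, assuming a Keras model compiled with an MSE loss and an OpenAI Gym environment such as CartPole-v1 (the environment, network shape, and loop structure here are illustrative assumptions, not part of the gist):

# Sketch only: assumes keras and gym are installed; the model/env choices are illustrative.
import gym
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

env = gym.make('CartPole-v1')
model = Sequential([
    Dense(24, activation='relu', input_shape=env.observation_space.shape),
    Dense(24, activation='relu'),
    Dense(env.action_space.n, activation='linear'),
])
model.compile(optimizer='adam', loss='mse')

agent = QAgent(model)
for episode in range(200):
    state = env.reset()
    done = False
    while not done:
        action = int(np.argmax(agent.get_Q(state)))   # greedy action from current Q estimates
        next_state, reward, done, _ = env.step(action)
        agent.remember(state, action, reward, next_state, done)
        state = next_state
    agent.learn(rounds=4)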
See also:
log_progress
Problem: how can I avoid the model accidentally converging on just picking one action? How can I make it actually learn the objective?
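
One common remedy (a sketch of ε-greedy exploration, not something the gist already does) is to take a random action with probability ε and the greedy action otherwise, decaying ε over time so the agent explores early and exploits later. The function and parameter names below are hypothetical:

# Hedged sketch: ε-greedy action selection with decay; names are illustrative, not from the gist.
import random
import numpy as np

def choose_action(agent, state, ε):
    # Explore with probability ε, otherwise act greedily on the current Q estimates.
    Q = agent.get_Q(state)
    if random.random() < ε:
        return random.randrange(len(Q))
    return int(np.argmax(Q))

ε, ε_min, ε_decay = 1.0, 0.05, 0.995
# Inside the episode loop:
#     action = choose_action(agent, state, ε)
#     ε = max(ε_min, ε * ε_decay)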