@sherjilozair
Created February 19, 2017 12:12
dqn impl in 150 lines

import random
import numpy as np
import gym
from scipy.misc import imresize
import tensorflow as tf
import tensorflow.contrib.slim as slim


class ReplayMemory:
    # Circular buffer of (action, reward, screen, terminal) tuples; minibatches
    # of stacked-frame states are sampled from it for Q-learning updates.
    def __init__(self, size=10**6, dims=(42, 42),
                 order=4, seqlen=1, mbsz=64):
        self.size = size
        self.order = order      # number of frames stacked into one state
        self.seqlen = seqlen    # number of transitions per sampled sequence
        self.dims = dims
        self.mbsz = mbsz
        self.count = self.current = 0
        self.totlen = self.order + self.seqlen
        # preallocate memory
        self.actions = np.empty(self.size, dtype=np.uint8)
        self.rewards = np.empty(self.size, dtype=np.integer)
        self.screens = np.empty((self.size,) + self.dims, dtype=np.uint8)
        self.terminals = np.empty(self.size, dtype=np.bool)
        self.indexes = np.empty((self.mbsz, self.totlen), dtype=np.integer)

    def add(self, action, reward, screen, terminal):
        assert screen.shape == self.dims, (screen.shape, self.dims)
        self.actions[self.current] = action
        self.rewards[self.current] = reward
        self.screens[self.current, ...] = screen
        self.terminals[self.current] = terminal
        self.count = max(self.count, self.current + 1)
        self.current = (self.current + 1) % self.size

    def sample(self):
        assert self.count > self.totlen
        # sample random indexes
        for j in range(self.mbsz):
            # find a random index whose preceding window neither wraps over the
            # write pointer nor crosses an episode boundary
            while True:
                index = random.randint(self.totlen, self.count - 1)
                if index - self.totlen < self.current <= index:
                    continue
                if self.terminals[(index - self.totlen):index].any():
                    continue
                break
            self.indexes[j, :] = (index - np.arange(self.totlen)[::-1]) % self.count
        screens = self.screens[self.indexes]
        # stack `order` consecutive frames along the last axis to form each state
        states = np.array([screens[:, i:self.seqlen+1+i, ...] / 255.
                           for i in range(self.order)]).transpose([1, 2, 3, 4, 0])
        actions = self.actions[self.indexes[:, -self.seqlen-1:-1]]
        rewards = self.rewards[self.indexes[:, -self.seqlen-1:-1]]
        terminals = self.terminals[self.indexes[:, -self.seqlen-1:-1]]
        return states, actions, rewards, terminals
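
# With the default hyperparameters (dims=(42, 42), order=4, seqlen=1, mbsz=64),
# ReplayMemory.sample() returns arrays with these shapes (worked out from the
# code above):
#
#   states:    (64, 2, 42, 42, 4)  # current and next frame-stacks, scaled to [0, 1]
#   actions:   (64, 1)             # action taken in the current state
#   rewards:   (64, 1)             # reward received for that action
#   terminals: (64, 1)             # whether that transition ended the episode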


class Network:
    # Q-network with a 'primary' copy that is trained and a 'target' copy that
    # is periodically synced to it and used for bootstrapped value estimates.
    def __init__(self, inp_dim, out_dim, update_freq=10):
        self.inp_dim = inp_dim
        self.out_dim = out_dim
        self.update_freq = update_freq
        self.step = tf.Variable(0, name='global_step', trainable=False)
        self.inputs = tf.placeholder(tf.float32, (None,) + self.inp_dim)
        self.targets = tf.placeholder(tf.float32, (None, self.out_dim))
        self.outputs = self.build_network(self.inputs, scope='primary')
        self.t_outputs = self.build_network(self.inputs, scope='target')
        self.primary_weights = slim.get_variables(scope='primary')
        self.target_weights = slim.get_variables(scope='target')
        self.loss = tf.reduce_mean(tf.square(self.outputs - self.targets))
        self.opt = tf.train.AdamOptimizer(1e-4)
        self.train_op = self.opt.minimize(self.loss, global_step=self.step,
                                          var_list=self.primary_weights)
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
        # op that copies the primary weights into the target network
        self.update_target = tf.group(*[self.target_weights[i].assign(self.primary_weights[i])
                                        for i in range(len(self.primary_weights))])

    def build_network(self, inputs, scope='scope'):
        with tf.variable_scope(scope):
            h = slim.conv2d(inputs, 32, [8, 8], stride=[4, 4], scope='conv1')
            h = slim.conv2d(h, 64, [4, 4], stride=[2, 2], scope='conv2')
            h = slim.conv2d(h, 64, [3, 3], stride=[1, 1], scope='conv3')
            h = slim.fully_connected(slim.flatten(h), 256, scope='dense1')
            return slim.fully_connected(h, self.out_dim, None, scope='dense2')

    def run(self, inputs, target=False):
        # Flatten any leading batch/time axes, run the chosen network, then restore them.
        outputs = self.t_outputs if target else self.outputs
        leading_dims = inputs.shape[:-3]
        inputs = inputs.reshape(*((-1,) + self.inp_dim))
        outputs = self.sess.run(outputs, {self.inputs: inputs})
        return outputs.reshape(leading_dims + (self.out_dim,)).squeeze()

    def update(self, inputs, targets):
        _, loss, step = self.sess.run([self.train_op, self.loss, self.step],
                                      {self.inputs: inputs, self.targets: targets})
        # sync the target network every `update_freq` gradient steps
        if step % self.update_freq == 0:
            self.sess.run(self.update_target)
        print(loss)
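
# Note on the target network: DQN stabilises training by computing bootstrap
# targets with a frozen copy of the Q-network. Here that copy lives in the
# 'target' variable scope and is overwritten with the 'primary' weights every
# update_freq (default 10) training steps via the update_target op. The original
# DQN setup syncs far less often (on the order of every 10,000 steps), so
# update_freq is a knob you would likely want to raise for real experiments.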


class Policy:
    # Greedy policy over the critic's Q-values, acting on a rolling stack of
    # the last `order` frames.
    def __init__(self, dims=(42, 42), n_acts=6, order=4, gamma=0.99):
        self.dims = dims
        self.n_acts = n_acts
        self.order = order
        self.gamma = gamma
        self.state = np.zeros(self.dims + (self.order,))  # rolling frame stack
        self.critic = Network(self.dims + (self.order,), n_acts)

    def act(self, screen):
        # shift the frame stack and append the newest screen, scaled to [0, 1]
        # to match the states sampled from replay memory
        self.state[..., :-1] = self.state[..., 1:]
        self.state[..., -1] = screen / 255.
        values = self.critic.run(self.state, target=False)
        action = values.argmax()
        return action

    def update(self, states, actions, rewards, terminals):
        # One-step Q-learning update: axis 1 of `states` holds the current
        # state (index 0) and the next state (index 1).
        lvalues = self.critic.run(states, target=False)  # primary-net Q-values
        rvalues = self.critic.run(states, target=True)   # target-net Q-values
        # assert rvalues.shape[1] == 2  # only for seqlen = 1
        rows = np.arange(len(actions))
        # TD error: clip(r) + gamma * max_a' Q_target(s', a') - Q(s, a),
        # with the bootstrap term masked out on terminal transitions
        dels = np.clip(rewards[:, 0].astype('float32'), -1., 1.)
        dels += self.gamma * rvalues[:, 1, :].max(axis=1) * (1. - terminals[:, 0])  # replace max with expec over policy
        dels -= lvalues[rows, 0, actions[:, 0]]
        target = lvalues[:, 0, :]  # (mbsz, n_actions)
        target[rows, actions[:, 0]] += dels
        self.critic.update(states[:, 0, ...], target)
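
# The act() above is purely greedy; standard DQN adds epsilon-greedy exploration,
# annealing epsilon from 1.0 toward 0.1 over roughly the first million steps. A
# minimal sketch of how that could be bolted onto this Policy (the helper name
# and the `eps` parameter are illustrative, not part of the gist):
#
#     def act_eps_greedy(self, screen, eps):
#         if random.random() < eps:
#             return random.randrange(self.n_acts)
#         return self.act(screen)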


class Agent:
    def __init__(self, env, memory, policy, iters=10**7, warmup=10**2):
        self.iters = iters
        self.warmup = warmup
        self.env = env
        self.memory = memory
        self.policy = policy

    def loop(self):
        done = True
        for i in range(self.iters):
            if done:
                self.env.reset()
                done = False
            # grab the raw grayscale screen from the ALE and downsample it
            screen = self.env.ale.getScreenGrayscale()
            screen = imresize(screen.squeeze(), self.memory.dims)
            action = self.policy.act(screen)
            _, reward, done, info = self.env.step(action)
            self.memory.add(action, reward, screen, done)
            if i > self.warmup:
                batch = self.memory.sample()
                self.policy.update(*batch)


if __name__ == '__main__':
    env = gym.make('Pong-v0')
    memory = ReplayMemory()
    policy = Policy()
    agent = Agent(env, memory, policy)
    agent.loop()
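
# Running this gist as written assumes the 2017-era stack it was built against:
# TensorFlow 1.x (tf.contrib.slim and tf.placeholder are gone in TF 2.x),
# SciPy < 1.3 (scipy.misc.imresize was removed in 1.3.0), and a gym install with
# the Atari ROMs (e.g. `pip install gym[atari]`) so that Pong-v0 and env.ale are
# available. On newer stacks the same structure applies, but the image resizing,
# the slim layers, and the session-based training loop all need porting.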