@alexlimh
Created April 29, 2020 13:58
dreamer.py
import argparse
import collections
import functools
import gc
import os
import pathlib
import resource
import sys
import warnings
warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'
import numpy as np
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
# tf.get_logger().setLevel('ERROR')
from tensorflow_probability import distributions as tfd
sys.path.append(str(pathlib.Path(__file__).parent))
import models
import tools
import wrappers
def define_config():
config = tools.AttrDict()
# General.
config.load_config = False
config.ckpt = True
config.logdir = pathlib.Path('.')
config.ckptdir = pathlib.Path('.')
config.traindir = None
config.evaldir = None
config.seed = 0
config.steps = 1e7
config.eval_every = 1e4
config.log_every = 1e4
config.backup_every = 1e3
config.gpu_growth = True
config.precision = 16
config.debug = False
config.expl_gifs = False
config.compress_dataset = False
config.gc = False
config.deleps = False
# Environment.
config.task = 'dmc_walker_walk'
config.size = (64, 64)
config.parallel = 'none'
config.envs = 1
config.action_repeat = 2
config.time_limit = 1000
config.prefill = 5000
config.eval_noise = 0.0
config.clip_rewards = 'identity'
config.atari_grayscale = False
config.atari_lifes = False
config.atari_fire_start = False
config.dmlab_path = None
# Model.
config.dynamics = 'rssm'
config.grad_heads = ('image', 'reward')
config.deter_size = 200
config.stoch_size = 50
config.stoch_depth = 10
config.dynamics_deep = False
config.units = 400
config.reward_layers = 2
config.discount_layers = 3
config.value_layers = 3
config.actor_layers = 4
config.act = 'elu'
config.cnn_depth = 32
config.encoder_kernels = (4, 4, 4, 4)
config.decoder_kernels = (5, 5, 6, 6)
config.decoder_thin = True
config.cnn_embed_size = 0
config.kl_scale = 1.0
config.free_nats = 3.0
config.kl_balance = 0.8
config.pred_discount = False
config.discount_rare_bonus = 0.0
config.discount_scale = 10.0
config.reward_scale = 2.0
config.weight_decay = 0.0
config.proprio = False
# Training.
config.batch_size = 50
config.batch_length = 50
config.train_every = 500
config.train_steps = 100
config.pretrain = 100
config.model_lr = 6e-4
config.value_lr = 8e-5
config.actor_lr = 8e-5
config.grad_clip = 100.0
config.opt_eps = 1e-5
config.value_grad_clip = 100.0
config.actor_grad_clip = 100.0
config.dataset_size = 0
config.dropout_embed = 0.0
config.dropout_feat = 0.0
config.dropout_imag = 0.0
config.oversample_ends = False
config.slow_value_target = False
config.slow_actor_target = False
config.slow_target_update = 0
config.slow_target_soft = False
config.slow_target_moving_weight = 0.1
# Behavior.
config.discount = 0.99
config.discount_lambda = 0.95
config.imag_horizon = 15
config.actor_dist = 'tanh_normal'
config.actor_disc = 5
config.actor_entropy = 0.0
config.actor_init_std = 5.0
config.expl = 'additive_gaussian'
config.expl_amount = 0.3
config.eval_state_mean = False
config.reuse_task_behavior = False
config.task_expl_reward_weights = 0.5
# Exploration.
config.expl_behavior = 'greedy'
config.expl_until = 0
config.disag_target = 'stoch' # embed
config.disag_scale = 5.0
config.disag_log = True
config.disag_samples = 10
config.disag_models = 10
config.disag_offset = 1
config.disag_layers = 4
config.vim_layers = 4
config.disag_units = 400
config.disag_project = 0 # 100
config.disag_noise_scale = 10.0
config.disag_save_memory = True
config.empow_scale = 1.0
config.empow_entropy = 'auto'
config.empow_marginal = 'sequences'
config.empow_sequences = 5 # 3
config.empow_samples = 5 # 3
config.temperature = 1.0
config.empow_horizon = 1
config.reuse_actor = False
return config
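# Example invocation (hypothetical log directory; every config key above becomes a
# matching command-line flag via the argparse loop at the bottom of this file):
#   python dreamer.py --task dmc_walker_walk --logdir ./logdir/walker_walk --precision 16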
class Dreamer(tools.Module):
def __init__(self, config, logger, dataset):
self._config = config
self._logger = logger
self._float = prec.global_policy().compute_dtype
self._should_log = tools.Every(config.log_every)
self._should_train = tools.Every(config.train_every)
self._should_pretrain = tools.Once()
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
self._metrics = collections.defaultdict(tf.metrics.Mean)
with tf.device('cpu:0'):
self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
# self._strategy = tf.distribute.MirroredStrategy()
# with self._strategy.scope():
# self._dataset = self._strategy.experimental_distribute_dataset(dataset)
self._dataset = iter(dataset)
self._world_model = models.WorldModel(config)
self._behavior = models.ImagBehavior(config, self._world_model)
if not config.reuse_task_behavior:
behavior = models.ImagBehavior(config, self._world_model)
else:
behavior = self._behavior
self._expl_behavior = dict(
greedy=lambda: self._behavior,
random=lambda: models.RandomBehavior(config),
disag=lambda: models.DisagBehavior(config, self._world_model, behavior),
disag_info=lambda: models.DisagInfoBehavior(config, self._world_model, behavior),
disag_noise=lambda: models.DisagNoiseBehavior(config, self._world_model, behavior),
empow_open=lambda: models.EmpowOpenBehavior(config, self._world_model, behavior),
empow_action=lambda: models.EmpowActionBehavior(config, self._world_model, behavior),
empow_state=lambda: models.EmpowStateBehavior(config, self._world_model, behavior),
empow_step_state=lambda: models.EmpowStepStateBehavior(config, self._world_model, behavior),
empow_step_action=lambda: models.EmpowStepActionBehavior(config, self._world_model, behavior),
empow_vim=lambda: models.EmpowVIMBehavior(config, self._world_model, behavior),
)[config.expl_behavior]()
# Train step to initialize variables including optimizer statistics.
self.train(next(self._dataset))
def __call__(self, obs, reset, state=None, training=True):
step = self._step.numpy().item()
if state is not None and reset.any():
mask = tf.cast(1 - reset, self._float)[:, None]
state = tf.nest.map_structure(lambda x: x * mask, state)
if training and self._should_train(step):
# with self._strategy.scope():
steps = (
self._config.pretrain if self._should_pretrain()
else self._config.train_steps)
for _ in range(steps):
self.train(next(self._dataset))
if self._should_log(step):
for name, mean in self._metrics.items():
self._logger.scalar(name, float(mean.result()))
mean.reset_states()
openl = self._world_model.video_pred(next(self._dataset))
self._logger.video('train_openl', openl)
self._logger.write(fps=True)
action, state = self._policy(obs, state, training)
if training:
self._step.assign_add(len(reset))
self._logger.step = self._config.action_repeat * self._step.numpy().item()
return action, state
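# _policy encodes the observation, updates the latent state with a posterior step,
# and picks an action: the actor's mode for evaluation, a sample from the exploration
# behavior while _should_expl is active, otherwise a sample from the task actor.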
@tf.function
def _policy(self, obs, state, training):
if state is None:
batch_size = len(obs['image'])
latent = self._world_model.dynamics.initial(len(obs['image']))
action = tf.zeros((batch_size, self._config.num_actions), self._float)
else:
latent, action = state
embed = self._world_model.encoder(self._world_model.preprocess(obs))
latent, _ = self._world_model.dynamics.obs_step(latent, action, embed)
if self._config.eval_state_mean:
latent['stoch'] = latent['mean']
feat = self._world_model.dynamics.get_feat(latent)
if not training:
action = self._behavior.actor(feat).mode()
elif self._should_expl(self._step):
action = self._expl_behavior.actor(feat).sample()
else:
action = self._behavior.actor(feat).sample()
action = self._exploration(action, training)
state = (latent, action)
return action, state
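# _exploration adds action noise on top of the policy: clipped additive Gaussian
# noise for continuous actions, or epsilon-greedy mixing for one-hot actions.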
def _exploration(self, action, training):
amount = self._config.expl_amount if training else self._config.eval_noise
if amount == 0:
return action
amount = tf.cast(amount, self._float)
if self._config.expl == 'additive_gaussian':
return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
if self._config.expl == 'epsilon_greedy':
probs = amount / self._config.num_actions + (1 - amount) * action
return tools.OneHotDist(probs=probs).sample()
raise NotImplementedError(self._config.expl)
# @tf.function
# def train(self, data):
# self._strategy.experimental_run_v2(self._train, (data,))
@tf.function
def train(self, data):
metrics = {}
embed, post, feat, mets = self._world_model.train(data)
metrics.update(mets)
reward = lambda f, s, a: self._world_model.heads['reward'](f).mode()
if not self._config.reuse_task_behavior:
metrics.update(self._behavior.train(post, reward))
if self._config.expl_behavior != 'greedy':
if self._config.reuse_task_behavior:
mets = self._expl_behavior.train(post, feat, embed, reward)
else:
mets = self._expl_behavior.train(post, feat, embed)
metrics.update({'expl_' + key: value for key, value in mets.items()})
# if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
# if self._config.pred_discount:
# metrics.update(self._episode_end_precision_recall(data, feat))
for name, value in metrics.items():
self._metrics[name].update_state(value)
# def _episode_end_precision_recall(self, data, feat):
# pred = self._world_model.heads['discount'](feat).mean()
# ppos, pneg = pred <= 0.5, pred > 0.5
# dpos, dneg = data['discount'] <= 0.5, data['discount'] > 0.5
# tp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dpos), self._float))
# tn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dneg), self._float))
# fp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dneg), self._float))
# fn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dpos), self._float))
# tpr, tnr, ppv = tp / (tp + fn), tn / (tn + fp), tp / (tp + fp)
# ba, f1 = (tpr + tnr) / 2, 2 * ppv * tpr / (ppv + tpr)
# any_ = tf.cast(tf.reduce_any(dpos), self._float)
# metrics = dict(
# ee_tp=tp, ee_tn=tn, ee_fp=fp, ee_fn=fn, ee_tpr=tpr, ee_tnr=tnr,
# ee_ppv=ppv, ee_ba=ba, ee_f1=f1, ee_any=any_)
# return {
# k: tf.where(tf.math.is_nan(v), tf.zeros_like(v), v)
# for k, v in metrics.items()}
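# count_steps infers the number of stored environment steps from the episode
# filenames, which appear to end in '-<length>.npz'; one step per file is
# subtracted, matching the length convention used in process_episode below.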
def count_steps(directory):
return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in directory.glob('*.npz'))
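# make_dataset builds a tf.data pipeline that repeatedly samples chunks of
# batch_length steps from the episode cache and batches them.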
def make_dataset(episodes, config):
example = episodes[next(iter(episodes.keys()))]
example = tools.decompress_episode(example)
types = {k: v.dtype for k, v in example.items()}
shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
generator = lambda: tools.sample_episodes(
episodes, config.batch_length, config.oversample_ends)
dataset = tf.data.Dataset.from_generator(generator, types, shapes)
dataset = dataset.batch(config.batch_size, drop_remainder=True)
dataset = dataset.prefetch(10)
return dataset
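# make_env constructs and wraps a single environment for the requested suite
# (dmc/dmctoy, atari, dmlab, mc), adds a time limit, and registers process_episode
# as a callback that stores and logs every completed episode.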
def make_env(config, logger, mode, train_eps, eval_eps):
suite, task = config.task.split('_', 1)
if suite == 'dmc' or suite == 'dmctoy':
env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
env = wrappers.NormalizeActions(env)
elif suite == 'atari':
env = wrappers.Atari(
task, config.action_repeat, config.size,
grayscale=config.atari_grayscale, sticky_actions=True,
life_done=config.atari_lifes and (mode == 'train'))
env = wrappers.OneHotAction(env)
elif suite == 'dmlab':
args = task, config.action_repeat, config.size
env = wrappers.DeepMindLab(*args, path=config.dmlab_path)
env = wrappers.OneHotAction(env)
elif suite == 'mc':
args = task, mode, config.size, config.action_repeat
env = wrappers.Minecraft(*args)
env = wrappers.OneHotAction(env)
else:
raise NotImplementedError(suite)
env = wrappers.TimeLimit(env, config.time_limit)
callbacks = [functools.partial(
process_episode, config, logger, mode, train_eps, eval_eps)]
env = wrappers.Collect(env, callbacks, config.precision)
env = wrappers.RewardObs(env)
return env
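# process_episode saves a finished episode to disk, optionally evicts old episodes
# to respect dataset_size, optionally compresses image data, and logs return,
# length, and episode count (plus a policy video for eval episodes).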
def process_episode(config, logger, mode, train_eps, eval_eps, episode):
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
cache = dict(train=train_eps, eval=eval_eps)[mode]
filename = tools.save_episodes(directory, [episode])[0]
length = len(episode['reward']) - 1
score = float(episode['reward'].astype(np.float32).sum())
video = episode['image']
if mode == 'eval':
if config.deleps:
for key, value in list(cache.items()):
del cache[key]
del value
cache.clear()
if mode == 'train' and config.dataset_size:
total = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
if total <= config.dataset_size - length:
total += len(ep['reward']) - 1
else:
del cache[key]
if config.deleps:
del ep
logger.scalar('dataset_size', total + length)
if config.compress_dataset:
episode = tools.compress_episode(episode, (np.uint8,))
cache[str(filename)] = episode
print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
logger.scalar(f'{mode}_return', score)
logger.scalar(f'{mode}_length', length)
logger.scalar(f'{mode}_episodes', len(cache))
if mode == 'eval' or config.expl_gifs:
logger.video(f'{mode}_policy', video[None])
logger.write()
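# main wires everything together: it applies per-suite defaults, rescales
# step-based settings by action_repeat, creates train/eval environments, prefills
# the replay cache with a random agent, then alternates evaluation and training
# (with periodic checkpoints) until config.steps is reached.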
def main(config):
if config.load_config:
from configs import default_configs
suite, _ = config.task.split('_', 1)
default = default_configs[suite]
config = vars(config)
for key, value in default.items():
config[key] = value
config = tools.AttrDict(config)
print(config)
config.traindir = config.traindir or config.logdir / 'train_eps'
config.evaldir = config.evaldir or config.logdir / 'eval_eps'
config.steps //= config.action_repeat
config.eval_every //= config.action_repeat
config.log_every //= config.action_repeat
config.backup_every //= config.action_repeat
config.time_limit //= config.action_repeat
config.act = getattr(tf.nn, config.act)
if config.empow_entropy == 'auto':
options = dict(
tanh_normal='sample', onehot='sample', discretized='analytic')
config.empow_entropy = options[config.actor_dist]
if config.debug:
tf.config.experimental_run_functions_eagerly(True)
if config.gpu_growth:
for gpu in tf.config.experimental.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(gpu, True)
assert config.precision in (16, 32), config.precision
if config.precision == 16:
prec.set_policy(prec.Policy('mixed_float16'))
print('Logdir', config.logdir)
config.logdir.mkdir(parents=True, exist_ok=True)
step = count_steps(config.traindir)
logger = tools.Logger(config.logdir, config.action_repeat * step)
print('Create envs.')
train_eps = tools.load_episodes(config.traindir, limit=config.dataset_size)
eval_eps = tools.load_episodes(config.evaldir, limit=1)
make1 = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
make2 = lambda mode: wrappers.Async(lambda: make1(mode), config.parallel)
train_envs = [make2('train') for _ in range(config.envs)]
eval_envs = [make2('eval') for _ in range(config.envs)]
acts = train_envs[0].action_space
config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]
prefill = max(0, config.prefill - count_steps(config.traindir))
print(f'Prefill dataset ({prefill} steps).')
random_agent = lambda o, d, s: ([acts.sample() for _ in d], s)
tools.simulate(random_agent, train_envs, prefill)
tools.simulate(random_agent, eval_envs, episodes=1)
logger.step = config.action_repeat * count_steps(config.traindir)
print('Simulate agent.')
train_dataset = make_dataset(train_eps, config)
eval_dataset = iter(make_dataset(eval_eps, config))
agent = Dreamer(config, logger, train_dataset)
if (config.logdir / 'variables.pkl').exists():
try:
agent.load(config.logdir / 'variables.pkl')
except Exception:
agent.load(config.logdir / 'variables-backup.pkl')
agent._should_pretrain._once = False
state = None
should_backup = tools.Every(config.backup_every)
while agent._step.numpy().item() < config.steps:
print('Process statistics.')
if config.gc:
gc.collect()
logger.scalar('python_objects', len(gc.get_objects()))
with open('/proc/self/statm') as f:
ram = int(f.read().split()[1]) * resource.getpagesize() / 1024 / 1024 / 1024
logger.scalar('ram', ram)
logger.write()
print('Start evaluation.')
video_pred = agent._world_model.video_pred(next(eval_dataset))
logger.video('eval_openl', video_pred)
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, episodes=1)
print('Start training.')
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
if config.ckpt:
agent.save(config.ckptdir / 'variables.pkl')
agent.save(config.logdir / 'variables.pkl')
if should_backup(agent._step.numpy().item()):
if config.ckpt:
agent.save(config.ckptdir / 'variables-backup.pkl')
agent.save(config.logdir / 'variables-backup.pkl')
for env in train_envs + eval_envs:
env.close()
if __name__ == '__main__':
print(os.getpid())
try:
import colored_traceback
colored_traceback.add_hook()
except ImportError:
pass
parser = argparse.ArgumentParser()
for key, value in define_config().items():
parser.add_argument(f'--{key}', type=tools.args_type(value), default=value)
main(parser.parse_args())
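# The classes below reference `networks` and `tfkl` (tf.keras.layers), which are
# not imported in this file; they are presumably excerpts from the accompanying
# model/network modules (e.g. models.py / networks.py) included for reference.
#
# EmpowStepActionBehavior estimates one-step empowerment I(S'; A | S) in action
# space: a planning head p(a | s, s') is fit by maximum likelihood on imagined
# transitions, while a source policy and a scalar offset head are regressed onto
# the temperature-scaled log plan probability; the learned offset, multiplied by
# the temperature, serves as the intrinsic reward for the imagination behavior
# (optionally combined with the task reward).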
class EmpowStepActionBehavior(tools.Module): # I(S';A|S), action space
def __init__(self, config, world_model, behavior):
self._config = config
self._world_model = world_model
self._behavior = behavior
self.actor = self._behavior.actor
self._tau = config.temperature # temperature
self._plan_ss = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc) # p(a_0|s, s')
if config.reuse_actor:
self._source = self._behavior.actor
else:
self._source = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc)
self._offset = networks.DenseHead(
[], config.actor_layers, config.units, config.act)
self._plan_opt = tools.Adam(
'plan', config.actor_lr, config.opt_eps, config.grad_clip,
config.weight_decay)
self._source_opt = tools.Adam(
'source', config.actor_lr, config.opt_eps, config.grad_clip,
config.weight_decay) # for both source and offset
def train(self, start, feat, embed, task_reward=None):
metrics = {}
metrics.update(self._behavior.train(start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
metrics.update(self._train_step(start))
return metrics
def _intrinsic_reward(self, feat, state, action, task_reward=None):
reward = self._tau * self._offset(feat).mode()
if task_reward:
t_reward = task_reward(feat, 0, 0)
reward = dict(expl=reward, task=t_reward)
return reward
def _imagine(self, start): # don't use this for self._behavior.train
flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = self._world_model.dynamics.get_feat(state)
if self._config.dropout_imag:
feat = tf.nn.dropout(feat, self._config.dropout_imag)
action = self._source(tf.stop_gradient(feat)).sample() # source instead of actor
state = self._world_model.dynamics.img_step(state, action)
return state, feat, action
range_ = tf.range(self._config.imag_horizon)
feat = 0 * self._world_model.dynamics.get_feat(start)
action = 0 * self._source(feat).mean()
states, feats, actions = tools.static_scan(step, range_, (start, feat, action))
return feats, states, actions
def _train_step(self, start):
feat, _, action = self._imagine(start)
ss_feat = tf.concat([feat[:-1], feat[1:]], axis=-1) # time-1, batch, features
action = tf.stop_gradient(tf.cast(action, tf.float32))
metrics = {}
with tf.GradientTape() as tape:
plan_pred = self._plan_ss(ss_feat, dtype=tf.float32)
log_plan_prob = plan_pred.log_prob(action[:-1])
loss = -tf.reduce_mean(log_plan_prob)
norm = self._plan_opt(tape, loss, [self._plan_ss])
metrics['plan_grad_norm'] = norm
metrics['plan_loss'] = loss
if self._config.precision == 16:
metrics['plan_dynamics_loss_scale'] = self._plan_opt.loss_scale
with tf.GradientTape() as tape:
offset = self._offset(feat[:-1], dtype=tf.float32).mode()
action_pred = self._source(feat[:-1], dtype=tf.float32)
log_action_prob = action_pred.log_prob(action[:-1])
r_sa = log_action_prob + offset
loss = tf.reduce_mean(tf.square(1./self._tau * tf.stop_gradient(log_plan_prob) - r_sa))
norm = self._source_opt(tape, loss, [self._source, self._offset])
metrics['source_grad_norm'] = norm
metrics['source_loss'] = loss
if self._config.precision == 16:
metrics['source_loss_scale'] = self._source_opt.loss_scale
return metrics
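# EmpowVIMBehavior is the n-step variant I(S_T; A_{1:T} | S): the source policy
# proposes an open-loop action sequence of length empow_horizon, the world model
# imagines the resulting final state, a plan head p(a_{1:T} | s, s_T) is fit by
# maximum likelihood, and the source/offset pair is regressed onto the scaled log
# plan probability, with the offset again providing the intrinsic reward.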
class EmpowVIMBehavior(tools.Module): # I(S^T;A|S), action space, n-step
def __init__(self, config, world_model, behavior):
self._config = config
self._world_model = world_model
self._behavior = behavior
self.actor = self._behavior.actor
self._tau = config.temperature # temperature
self._plan = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc) # p(a_1:t|s, s')
if config.reuse_actor:
self._source = self._behavior.actor
else:
self._source = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc) # p(a_1:t|s)
self._offset = networks.DenseHead(
[], config.actor_layers, config.units, config.act)
self._plan_opt = tools.Adam(
'plan', config.actor_lr, config.opt_eps, config.grad_clip,
config.weight_decay)
self._source_opt = tools.Adam(
'source', config.actor_lr, config.opt_eps, config.grad_clip,
config.weight_decay) # for both source and offset
def train(self, start, feat, embed, task_reward=None):
metrics = {}
metrics.update(self._behavior.train(start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
metrics.update(self._train_vim(start))
return metrics
def _intrinsic_reward(self, feat, state, action, task_reward=None):
reward = self._tau * self._offset(feat).mode()
if task_reward:
t_reward = task_reward(feat, 0, 0)
reward = dict(expl=reward, task=t_reward)
return reward
def _imagine(self, start, actions):
def step(prev, action):
state, _ = prev
feat = self._world_model.dynamics.get_feat(state)
if self._config.dropout_imag:
feat = tf.nn.dropout(feat, self._config.dropout_imag)
action = tf.cast(action, feat.dtype)
state = self._world_model.dynamics.img_step(state, action)
return state, feat
feat = self._world_model.dynamics.get_feat(start)
states, feats = tools.static_scan(step, actions, (start, feat))
return feats[-1]
def _train_vim(self, start):
flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
feat = self._world_model.dynamics.get_feat(start)
metrics = {}
# Sample an open-loop action sequence from the source policy
action_pred, action_list = self._source(feat, dtype=tf.float32, horizon=self._config.empow_horizon)
# Execute it in the model and retrieve the last state
last_feat = self._imagine(start, tf.stack(action_list, axis=0))
with tf.GradientTape() as tape:
plan_pred, _ = self._plan(tf.concat([feat, last_feat], axis=-1), dtype=tf.float32, horizon=self._config.empow_horizon)
log_plan_prob = [pp.log_prob(tf.stop_gradient(a)) for a, pp in zip(action_list, plan_pred)]
log_plan_prob = tf.reduce_mean(tf.stack(log_plan_prob, axis=0), axis=0)
loss = -tf.reduce_mean(log_plan_prob)
norm = self._plan_opt(tape, loss, [self._plan])
metrics['plan_grad_norm'] = norm
metrics['plan_loss'] = loss
if self._config.precision == 16:
metrics['plan_loss_scale'] = self._plan_opt.loss_scale
with tf.GradientTape() as tape:
offset = self._offset(feat, dtype=tf.float32).mode()
log_action_prob = [ap.log_prob(tf.stop_gradient(a)) for a, ap in zip(action_list, action_pred)]
log_action_prob = tf.reduce_mean(tf.stack(log_action_prob, axis=0), axis=0)
r_sa = log_action_prob + offset
loss = tf.reduce_mean(tf.square(1./self._tau * tf.stop_gradient(log_plan_prob) - r_sa))
norm = self._source_opt(tape, loss, [self._source, self._offset])
metrics['source_grad_norm'] = norm
metrics['source_loss'] = loss
if self._config.precision == 16:
metrics['source_loss_scale'] = self._source_opt.loss_scale
return metrics
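# ActionHead maps features to an action distribution (tanh_normal, normal, onehot,
# or discretized). With horizon=0 it returns a single distribution; with horizon>0
# it unrolls a GRU cell over sampled actions and returns the per-step distributions
# together with their samples. `tfkl` refers to tf.keras.layers, imported in the
# module this class comes from.
#
# Hypothetical usage sketch (names and shapes assumed, not taken from this file):
#   head = ActionHead(size=num_actions, layers=4, units=400, dist='onehot')
#   dists, actions = head(feat, horizon=3)  # three per-step distributions + samples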
class ActionHead(tools.Module):
def __init__(
self, size, layers, units, act=tf.nn.elu, dist='tanh_normal',
init_std=5.0, min_std=1e-4, mean_scale=5.0, action_disc=5):
self._size = size
self._layers = layers
self._units = units
self._dist = dist
self._act = act
self._min_std = min_std
self._init_std = init_std
self._mean_scale = mean_scale
self._action_disc = action_disc
self._cell = tfkl.GRUCell(self._units)
def __call__(self, features, dtype=None, horizon=0):
x = features
for index in range(self._layers):
x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
h = x # init hidden state
act_dist = [self.action(h, dtype)]
if horizon == 0:
self.hidden = h
return act_dist[0]
act = [act_dist[-1].sample()]
self.hidden = [h]
for _ in range(horizon-1):
h, _ = self._cell(act[-1], [h])
act_dist.append(self.action(h, dtype))
act.append(act_dist[-1].sample())
self.hidden.append(h)
return act_dist, act
def action(self, features, dtype):
raw_init_std = np.log(np.exp(self._init_std) - 1)
x = features
if self._dist == 'tanh_normal':
# https://www.desmos.com/calculator/rcmcf5jwe7
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
std = tf.nn.softplus(std + raw_init_std) + self._min_std
dist = tfd.Normal(mean, std)
dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
dist = tfd.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == 'normal':
x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
if dtype:
x = tf.cast(x, dtype)
mean, std = tf.split(x, 2, -1)
mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
std = tf.nn.softplus(std + raw_init_std) + self._min_std
dist = tfd.Normal(mean, std)
dist = tfd.Independent(dist, 1)
elif self._dist == 'onehot':
x = self.get(f'hout', tfkl.Dense, self._size)(x)
if dtype:
x = tf.cast(x, dtype)
dist = tools.OneHotDist(x, dtype=dtype)
elif self._dist == 'discretized':
x = self.get(f'hout', tfkl.Dense, self._size * self._action_disc)(x)
x = tf.reshape(x, x.shape[:-1] + [self._size, self._action_disc])
if dtype:
x = tf.cast(x, dtype)
dist = tfd.Independent(tools.UnifDiscDist(x, dtype=dtype), 1)
else:
raise NotImplementedError(self._dist)
return dist