@alexlimh
Last active April 29, 2020 14:02
dreamer.py: add config.reuse_actor; add EmpowStepActionBehavior. models.py: add EmpowStepActionBehavior; change EmpowVIMBehavior. networks.py: change ActionHead class.
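For reference (added note, not verified): the new objective is still selected through the existing argparse flags, so an invocation along the lines of `python dreamer.py --logdir ./logdir --task dmc_walker_walk --expl_behavior empow_step_action --reuse_actor True` should exercise the added code paths; flag names follow `define_config` below.

# dreamer.py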
import argparse
import collections
import functools
import gc
import os
import pathlib
import resource
import sys
import warnings
warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'
import numpy as np
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
# tf.get_logger().setLevel('ERROR')
from tensorflow_probability import distributions as tfd
sys.path.append(str(pathlib.Path(__file__).parent))
import models
import tools
import wrappers
def define_config():
  config = tools.AttrDict()
  # General.
  config.load_config = False
  config.ckpt = True
  config.logdir = pathlib.Path('.')
  config.ckptdir = pathlib.Path('.')
  config.traindir = None
  config.evaldir = None
  config.seed = 0
  config.steps = 1e7
  config.eval_every = 1e4
  config.log_every = 1e4
  config.backup_every = 1e3
  config.gpu_growth = True
  config.precision = 16
  config.debug = False
  config.expl_gifs = False
  config.compress_dataset = False
  config.gc = False
  config.deleps = False
  # Environment.
  config.task = 'dmc_walker_walk'
  config.size = (64, 64)
  config.parallel = 'none'
  config.envs = 1
  config.action_repeat = 2
  config.time_limit = 1000
  config.prefill = 5000
  config.eval_noise = 0.0
  config.clip_rewards = 'identity'
  config.atari_grayscale = False
  config.atari_lifes = False
  config.atari_fire_start = False
  config.dmlab_path = None
  # Model.
  config.dynamics = 'rssm'
  config.grad_heads = ('image', 'reward')
  config.deter_size = 200
  config.stoch_size = 50
  config.stoch_depth = 10
  config.dynamics_deep = False
  config.units = 400
  config.reward_layers = 2
  config.discount_layers = 3
  config.value_layers = 3
  config.actor_layers = 4
  config.act = 'elu'
  config.cnn_depth = 32
  config.encoder_kernels = (4, 4, 4, 4)
  config.decoder_kernels = (5, 5, 6, 6)
  config.decoder_thin = True
  config.cnn_embed_size = 0
  config.kl_scale = 1.0
  config.free_nats = 3.0
  config.kl_balance = 0.8
  config.pred_discount = False
  config.discount_rare_bonus = 0.0
  config.discount_scale = 10.0
  config.reward_scale = 2.0
  config.weight_decay = 0.0
  config.proprio = False
  # Training.
  config.batch_size = 50
  config.batch_length = 50
  config.train_every = 500
  config.train_steps = 100
  config.pretrain = 100
  config.model_lr = 6e-4
  config.value_lr = 8e-5
  config.actor_lr = 8e-5
  config.grad_clip = 100.0
  config.opt_eps = 1e-5
  config.value_grad_clip = 100.0
  config.actor_grad_clip = 100.0
  config.dataset_size = 0
  config.dropout_embed = 0.0
  config.dropout_feat = 0.0
  config.dropout_imag = 0.0
  config.oversample_ends = False
  config.slow_value_target = False
  config.slow_actor_target = False
  config.slow_target_update = 0
  config.slow_target_soft = False
  config.slow_target_moving_weight = 0.1
  # Behavior.
  config.discount = 0.99
  config.discount_lambda = 0.95
  config.imag_horizon = 15
  config.actor_dist = 'tanh_normal'
  config.actor_disc = 5
  config.actor_entropy = 0.0
  config.actor_init_std = 5.0
  config.expl = 'additive_gaussian'
  config.expl_amount = 0.3
  config.eval_state_mean = False
  config.reuse_task_behavior = False
  config.task_expl_reward_weights = 0.5
  # Exploration.
  config.expl_behavior = 'greedy'
  config.expl_until = 0
  config.disag_target = 'stoch'  # embed
  config.disag_scale = 5.0
  config.disag_log = True
  config.disag_samples = 10
  config.disag_models = 10
  config.disag_offset = 1
  config.disag_layers = 4
  config.vim_layers = 4
  config.disag_units = 400
  config.disag_project = 0  # 100
  config.disag_noise_scale = 10.0
  config.disag_save_memory = True
  config.empow_scale = 1.0
  config.empow_entropy = 'auto'
  config.empow_marginal = 'sequences'
  config.empow_sequences = 5  # 3
  config.empow_samples = 5  # 3
  config.temperature = 1.0
  config.empow_horizon = 1
  config.reuse_actor = False
  return config
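

# Example override (added, illustrative sketch): the new empowerment options are
# ordinary config entries, so they can be set from the command line or in code
# before constructing the agent, e.g.:
#   config = define_config()
#   config.expl_behavior = 'empow_step_action'
#   config.reuse_actor = True          # share the task actor as the source p(a|s)
#   config.reuse_task_behavior = True  # mix task and exploration rewards

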
class Dreamer(tools.Module):

  def __init__(self, config, logger, dataset):
    self._config = config
    self._logger = logger
    self._float = prec.global_policy().compute_dtype
    self._should_log = tools.Every(config.log_every)
    self._should_train = tools.Every(config.train_every)
    self._should_pretrain = tools.Once()
    self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
    self._metrics = collections.defaultdict(tf.metrics.Mean)
    with tf.device('cpu:0'):
      self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
    # self._strategy = tf.distribute.MirroredStrategy()
    # with self._strategy.scope():
    #   self._dataset = self._strategy.experimental_distribute_dataset(dataset)
    self._dataset = iter(dataset)
    self._world_model = models.WorldModel(config)
    self._behavior = models.ImagBehavior(config, self._world_model)
    if not config.reuse_task_behavior:
      behavior = models.ImagBehavior(config, self._world_model)
    else:
      behavior = self._behavior
    self._expl_behavior = dict(
        greedy=lambda: self._behavior,
        random=lambda: models.RandomBehavior(config),
        disag=lambda: models.DisagBehavior(config, self._world_model, behavior),
        disag_info=lambda: models.DisagInfoBehavior(config, self._world_model, behavior),
        disag_noise=lambda: models.DisagNoiseBehavior(config, self._world_model, behavior),
        empow_open=lambda: models.EmpowOpenBehavior(config, self._world_model, behavior),
        empow_action=lambda: models.EmpowActionBehavior(config, self._world_model, behavior),
        empow_state=lambda: models.EmpowStateBehavior(config, self._world_model, behavior),
        empow_step=lambda: models.EmpowStepBehavior(config, self._world_model, behavior),
        empow_step_action=lambda: models.EmpowStepActionBehavior(config, self._world_model, behavior),
        empow_vim=lambda: models.EmpowVIMBehavior(config, self._world_model, behavior),
    )[config.expl_behavior]()
    # Train step to initialize variables including optimizer statistics.
    self.train(next(self._dataset))

  def __call__(self, obs, reset, state=None, training=True):
    step = self._step.numpy().item()
    if state is not None and reset.any():
      mask = tf.cast(1 - reset, self._float)[:, None]
      state = tf.nest.map_structure(lambda x: x * mask, state)
    if training and self._should_train(step):
      # with self._strategy.scope():
      steps = (
          self._config.pretrain if self._should_pretrain()
          else self._config.train_steps)
      for _ in range(steps):
        self.train(next(self._dataset))
      if self._should_log(step):
        for name, mean in self._metrics.items():
          self._logger.scalar(name, float(mean.result()))
          mean.reset_states()
        openl = self._world_model.video_pred(next(self._dataset))
        self._logger.video('train_openl', openl)
        self._logger.write(fps=True)
    action, state = self._policy(obs, state, training)
    if training:
      self._step.assign_add(len(reset))
      self._logger.step = self._config.action_repeat * self._step.numpy().item()
    return action, state

  @tf.function
  def _policy(self, obs, state, training):
    if state is None:
      batch_size = len(obs['image'])
      latent = self._world_model.dynamics.initial(len(obs['image']))
      action = tf.zeros((batch_size, self._config.num_actions), self._float)
    else:
      latent, action = state
    embed = self._world_model.encoder(self._world_model.preprocess(obs))
    latent, _ = self._world_model.dynamics.obs_step(latent, action, embed)
    if self._config.eval_state_mean:
      latent['stoch'] = latent['mean']
    feat = self._world_model.dynamics.get_feat(latent)
    if not training:
      action = self._behavior.actor(feat).mode()
    elif self._should_expl(self._step):
      action = self._expl_behavior.actor(feat).sample()
    else:
      action = self._behavior.actor(feat).sample()
    action = self._exploration(action, training)
    state = (latent, action)
    return action, state

  def _exploration(self, action, training):
    amount = self._config.expl_amount if training else self._config.eval_noise
    if amount == 0:
      return action
    amount = tf.cast(amount, self._float)
    if self._config.expl == 'additive_gaussian':
      return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
    if self._config.expl == 'epsilon_greedy':
      probs = amount / self._config.num_actions + (1 - amount) * action
      return tools.OneHotDist(probs=probs).sample()
    raise NotImplementedError(self._config.expl)

  # @tf.function
  # def train(self, data):
  #   self._strategy.experimental_run_v2(self._train, (data,))

  @tf.function
  def train(self, data):
    metrics = {}
    embed, post, feat, mets = self._world_model.train(data)
    metrics.update(mets)
    reward = lambda f, s, a: self._world_model.heads['reward'](f).mode()
    if not self._config.reuse_task_behavior:
      metrics.update(self._behavior.train(post, reward))
    if self._config.expl_behavior != 'greedy':
      if self._config.reuse_task_behavior:
        mets = self._expl_behavior.train(post, feat, embed, reward)
      else:
        mets = self._expl_behavior.train(post, feat, embed)
      metrics.update({'expl_' + key: value for key, value in mets.items()})
    # if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
    #   if self._config.pred_discount:
    #     metrics.update(self._episode_end_precision_recall(data, feat))
    for name, value in metrics.items():
      self._metrics[name].update_state(value)

  # def _episode_end_precision_recall(self, data, feat):
  #   pred = self._world_model.heads['discount'](feat).mean()
  #   ppos, pneg = pred <= 0.5, pred > 0.5
  #   dpos, dneg = data['discount'] <= 0.5, data['discount'] > 0.5
  #   tp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dpos), self._float))
  #   tn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dneg), self._float))
  #   fp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dneg), self._float))
  #   fn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dpos), self._float))
  #   tpr, tnr, ppv = tp / (tp + fn), tn / (tn + fp), tp / (tp + fp)
  #   ba, f1 = (tpr + tnr) / 2, 2 * ppv * tpr / (ppv + tpr)
  #   any_ = tf.cast(tf.reduce_any(dpos), self._float)
  #   metrics = dict(
  #       ee_tp=tp, ee_tn=tn, ee_fp=fp, ee_fn=fn, ee_tpr=tpr, ee_tnr=tnr,
  #       ee_ppv=ppv, ee_ba=ba, ee_f1=f1, ee_any=any_)
  #   return {
  #       k: tf.where(tf.math.is_nan(v), tf.zeros_like(v), v)
  #       for k, v in metrics.items()}
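
  # Note (added annotation, not in the original code): when
  # config.reuse_task_behavior is set, the task ImagBehavior doubles as the
  # exploration behavior and train() forwards the task reward into the
  # exploration objective; the Empow*Behavior classes accept it as
  # `task_reward` and return a dict(expl=..., task=...) reward.

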
def count_steps(directory):
  return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in directory.glob('*.npz'))


def make_dataset(episodes, config):
  example = episodes[next(iter(episodes.keys()))]
  example = tools.decompress_episode(example)
  types = {k: v.dtype for k, v in example.items()}
  shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
  generator = lambda: tools.sample_episodes(
      episodes, config.batch_length, config.oversample_ends)
  dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  dataset = dataset.batch(config.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(10)
  return dataset
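

# Shape note (added, approximate): each dataset element is a dict of tensors
# shaped (config.batch_size, config.batch_length, ...), e.g. roughly
# (50, 50, 64, 64, 3) for 'image' with the defaults above; the exact keys
# depend on what the Collect wrapper records.
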
def make_env(config, logger, mode, train_eps, eval_eps):
  suite, task = config.task.split('_', 1)
  if suite == 'dmc' or suite == 'dmctoy':
    env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
    env = wrappers.NormalizeActions(env)
  elif suite == 'atari':
    env = wrappers.Atari(
        task, config.action_repeat, config.size,
        grayscale=config.atari_grayscale, sticky_actions=True,
        life_done=config.atari_lifes and (mode == 'train'))
    env = wrappers.OneHotAction(env)
  elif suite == 'dmlab':
    args = task, config.action_repeat, config.size
    env = wrappers.DeepMindLab(*args, path=config.dmlab_path)
    env = wrappers.OneHotAction(env)
  elif suite == 'mc':
    args = task, mode, config.size, config.action_repeat
    env = wrappers.Minecraft(*args)
    env = wrappers.OneHotAction(env)
  else:
    raise NotImplementedError(suite)
  env = wrappers.TimeLimit(env, config.time_limit)
  callbacks = [functools.partial(
      process_episode, config, logger, mode, train_eps, eval_eps)]
  env = wrappers.Collect(env, callbacks, config.precision)
  env = wrappers.RewardObs(env)
  return env

def process_episode(config, logger, mode, train_eps, eval_eps, episode):
  directory = dict(train=config.traindir, eval=config.evaldir)[mode]
  cache = dict(train=train_eps, eval=eval_eps)[mode]
  filename = tools.save_episodes(directory, [episode])[0]
  length = len(episode['reward']) - 1
  score = float(episode['reward'].astype(np.float32).sum())
  video = episode['image']
  if mode == 'eval':
    if config.deleps:
      for key, value in list(cache.items()):
        del cache[key]
        del value
    cache.clear()
  if mode == 'train' and config.dataset_size:
    total = 0
    for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
      if total <= config.dataset_size - length:
        total += len(ep['reward']) - 1
      else:
        del cache[key]
        if config.deleps:
          del ep
    logger.scalar('dataset_size', total + length)
  if config.compress_dataset:
    episode = tools.compress_episode(episode, (np.uint8,))
  cache[str(filename)] = episode
  print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
  logger.scalar(f'{mode}_return', score)
  logger.scalar(f'{mode}_length', length)
  logger.scalar(f'{mode}_episodes', len(cache))
  if mode == 'eval' or config.expl_gifs:
    logger.video(f'{mode}_policy', video[None])
  logger.write()

def main(config):
  if config.load_config:
    from configs import default_configs
    suite, _ = config.task.split('_', 1)
    default = default_configs[suite]
    config = vars(config)
    for key, value in default.items():
      config[key] = value
    config = tools.AttrDict(config)
  print(config)
  config.traindir = config.traindir or config.logdir / 'train_eps'
  config.evaldir = config.evaldir or config.logdir / 'eval_eps'
  config.steps //= config.action_repeat
  config.eval_every //= config.action_repeat
  config.log_every //= config.action_repeat
  config.backup_every //= config.action_repeat
  config.time_limit //= config.action_repeat
  config.act = getattr(tf.nn, config.act)
  if config.empow_entropy == 'auto':
    options = dict(
        tanh_normal='sample', onehot='sample', discretized='analytic')
    config.empow_entropy = options[config.actor_dist]
  if config.debug:
    tf.config.experimental_run_functions_eagerly(True)
  if config.gpu_growth:
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
      tf.config.experimental.set_memory_growth(gpu, True)
  assert config.precision in (16, 32), config.precision
  if config.precision == 16:
    prec.set_policy(prec.Policy('mixed_float16'))
  print('Logdir', config.logdir)
  config.logdir.mkdir(parents=True, exist_ok=True)
  step = count_steps(config.traindir)
  logger = tools.Logger(config.logdir, config.action_repeat * step)

  print('Create envs.')
  train_eps = tools.load_episodes(config.traindir, limit=config.dataset_size)
  eval_eps = tools.load_episodes(config.evaldir, limit=1)
  make1 = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
  make2 = lambda mode: wrappers.Async(lambda: make1(mode), config.parallel)
  train_envs = [make2('train') for _ in range(config.envs)]
  eval_envs = [make2('eval') for _ in range(config.envs)]
  acts = train_envs[0].action_space
  config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]
  prefill = max(0, config.prefill - count_steps(config.traindir))
  print(f'Prefill dataset ({prefill} steps).')
  random_agent = lambda o, d, s: ([acts.sample() for _ in d], s)
  tools.simulate(random_agent, train_envs, prefill)
  tools.simulate(random_agent, eval_envs, episodes=1)
  logger.step = config.action_repeat * count_steps(config.traindir)

  print('Simulate agent.')
  train_dataset = make_dataset(train_eps, config)
  eval_dataset = iter(make_dataset(eval_eps, config))
  agent = Dreamer(config, logger, train_dataset)
  if (config.logdir / 'variables.pkl').exists():
    try:
      agent.load(config.logdir / 'variables.pkl')
    except Exception:
      agent.load(config.logdir / 'variables-backup.pkl')
    agent._should_pretrain._once = False
  state = None
  should_backup = tools.Every(config.backup_every)
  while agent._step.numpy().item() < config.steps:
    print('Process statistics.')
    if config.gc:
      gc.collect()
      logger.scalar('python_objects', len(gc.get_objects()))
    with open('/proc/self/statm') as f:
      ram = int(f.read().split()[1]) * resource.getpagesize() / 1024 / 1024 / 1024
    logger.scalar('ram', ram)
    logger.write()
    print('Start evaluation.')
    video_pred = agent._world_model.video_pred(next(eval_dataset))
    logger.video('eval_openl', video_pred)
    eval_policy = functools.partial(agent, training=False)
    tools.simulate(eval_policy, eval_envs, episodes=1)
    print('Start training.')
    state = tools.simulate(agent, train_envs, config.eval_every, state=state)
    if config.ckpt:
      agent.save(config.ckptdir / 'variables.pkl')
    agent.save(config.logdir / 'variables.pkl')
    if should_backup(agent._step.numpy().item()):
      if config.ckpt:
        agent.save(config.ckptdir / 'variables-backup.pkl')
      agent.save(config.logdir / 'variables-backup.pkl')
  for env in train_envs + eval_envs:
    env.close()

if __name__ == '__main__':
  print(os.getpid())
  try:
    import colored_traceback
    colored_traceback.add_hook()
  except ImportError:
    pass
  parser = argparse.ArgumentParser()
  for key, value in define_config().items():
    parser.add_argument(f'--{key}', type=tools.args_type(value), default=value)
  main(parser.parse_args())
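

# ---------------------------------------------------------------------------
# models.py (per the gist description above). Only the added/changed classes
# are shown; they are assumed to live in models.py, which presumably imports
# `networks`, `tools`, and TensorFlow the same way dreamer.py does.
# ---------------------------------------------------------------------------
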
class EmpowStepActionBehavior(tools.Module): # I(S';A|S), action space

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor
    self._tau = config.temperature  # temperature
    self._plan_ss = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_disc)  # p(a_0|s, s')
    if config.reuse_actor:
      self._source = self._behavior.actor
    else:
      self._source = networks.ActionHead(
          config.num_actions, config.actor_layers, config.units, config.act,
          config.actor_dist, config.actor_init_std, config.actor_disc)
    self._offset = networks.DenseHead(
        [], config.actor_layers, config.units, config.act)
    self._plan_opt = tools.Adam(
        'plan', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._source_opt = tools.Adam(
        'source', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)  # for both source and offset

  def train(self, start, feat, embed, task_reward=None):
    metrics = {}
    metrics.update(self._behavior.train(start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
    metrics.update(self._train_step(start))
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    reward = self._tau * self._offset(feat).mode()
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward

  def _imagine(self, start):  # don't use this for self._behavior.train
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    def step(prev, _):
      state, _, _ = prev
      feat = self._world_model.dynamics.get_feat(state)
      if self._config.dropout_imag:
        feat = tf.nn.dropout(feat, self._config.dropout_imag)
      action = self._source(tf.stop_gradient(feat)).sample()  # source instead of actor
      state = self._world_model.dynamics.img_step(state, action)
      return state, feat, action
    range_ = tf.range(self._config.imag_horizon)
    feat = 0 * self._world_model.dynamics.get_feat(start)
    action = 0 * self._source(feat).mean()
    states, feats, actions = tools.static_scan(step, range_, (start, feat, action))
    return feats, states, actions

  def _train_step(self, start):
    feat, _, action = self._imagine(start)
    ss_feat = tf.concat([feat[:-1], feat[1:]], axis=-1)  # time-1, batch, features
    action = tf.stop_gradient(tf.cast(action, tf.float32))
    metrics = {}
    with tf.GradientTape() as tape:
      plan_pred = self._plan_ss(ss_feat, dtype=tf.float32)
      log_plan_prob = plan_pred.log_prob(action[:-1])
      loss = -tf.reduce_mean(log_plan_prob)
    norm = self._plan_opt(tape, loss, [self._plan_ss])
    metrics['plan_grad_norm'] = norm
    metrics['plan_loss'] = loss
    if self._config.precision == 16:
      metrics['plan_dynamics_loss_scale'] = self._plan_opt.loss_scale
    with tf.GradientTape() as tape:
      offset = self._offset(feat[:-1], dtype=tf.float32).mode()
      action_pred = self._source(feat[:-1], dtype=tf.float32)
      log_action_prob = action_pred.log_prob(action[:-1])
      r_sa = log_action_prob + offset
      loss = tf.reduce_mean(tf.square(1./self._tau * tf.stop_gradient(log_plan_prob) - r_sa))
    norm = self._source_opt(tape, loss, [self._source, self._offset])
    metrics['source_grad_norm'] = norm
    metrics['source_loss'] = loss
    if self._config.precision == 16:
      metrics['source_loss_scale'] = self._source_opt.loss_scale
    return metrics
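
  # Reading of the losses above (added annotation, informal): _plan_ss fits an
  # inverse model q(a | s, s') by maximum likelihood on imagined transitions,
  # while the source pi(a | s) and the scalar offset b(s) are regressed so that
  #   log pi(a | s) + b(s) ~ (1 / tau) * log q(a | s, s'),
  # so b(s) tracks a (temperature-scaled) variational estimate of the one-step
  # empowerment I(S'; A | S), and tau * b(s) is used as the intrinsic reward.

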
class EmpowVIMBehavior(tools.Module): # I(S^T;A|S), action space, n-step

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor
    self._tau = config.temperature  # temperature
    self._plan = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_disc)  # p(a_1:t|s, s')
    if config.reuse_actor:
      self._source = self._behavior.actor
    else:
      self._source = networks.ActionHead(
          config.num_actions, config.actor_layers, config.units, config.act,
          config.actor_dist, config.actor_init_std, config.actor_disc)  # p(a_1:t|s)
    self._offset = networks.DenseHead(
        [], config.actor_layers, config.units, config.act)
    self._plan_opt = tools.Adam(
        'plan', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._source_opt = tools.Adam(
        'source', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)  # for both source and offset

  def train(self, start, feat, embed, task_reward=None):
    metrics = {}
    metrics.update(self._behavior.train(start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
    metrics.update(self._train_vim(start))
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    reward = self._tau * self._offset(feat).mode()
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward

  def _imagine(self, start, actions):
    def step(prev, action):
      state, _ = prev
      feat = self._world_model.dynamics.get_feat(state)
      if self._config.dropout_imag:
        feat = tf.nn.dropout(feat, self._config.dropout_imag)
      action = tf.cast(action, feat.dtype)
      state = self._world_model.dynamics.img_step(state, action)
      return state, feat
    feat = self._world_model.dynamics.get_feat(start)
    states, feats = tools.static_scan(step, actions, (start, feat))
    return feats[-1]

  def _train_vim(self, start):
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    feat = self._world_model.dynamics.get_feat(start)
    metrics = {}
    # Get the plan
    action_pred, action_list = self._source(feat, dtype=tf.float32, horizon=self._config.empow_horizon)
    # Execute it in the model and retrieve the last state
    last_feat = self._imagine(start, tf.stack(action_list, axis=0))
    with tf.GradientTape() as tape:
      plan_pred, _ = self._plan(tf.concat([feat, last_feat], axis=-1), dtype=tf.float32, horizon=self._config.empow_horizon)
      log_plan_prob = [pp.log_prob(tf.stop_gradient(a)) for a, pp in zip(action_list, plan_pred)]
      log_plan_prob = tf.reduce_mean(tf.stack(log_plan_prob, axis=0), axis=0)
      loss = -tf.reduce_mean(log_plan_prob)
    norm = self._plan_opt(tape, loss, [self._plan])
    metrics['plan_grad_norm'] = norm
    metrics['plan_loss'] = loss
    if self._config.precision == 16:
      metrics['plan_loss_scale'] = self._plan_opt.loss_scale
    with tf.GradientTape() as tape:
      offset = self._offset(feat, dtype=tf.float32).mode()
      log_action_prob = [ap.log_prob(tf.stop_gradient(a)) for a, ap in zip(action_list, action_pred)]
      log_action_prob = tf.reduce_mean(tf.stack(log_action_prob, axis=0), axis=0)
      r_sa = log_action_prob + offset
      loss = tf.reduce_mean(tf.square(1./self._tau * tf.stop_gradient(log_plan_prob) - r_sa))
    norm = self._source_opt(tape, loss, [self._source, self._offset])
    metrics['source_grad_norm'] = norm
    metrics['source_loss'] = loss
    if self._config.precision == 16:
      metrics['source_loss_scale'] = self._source_opt.loss_scale
    return metrics
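

# ---------------------------------------------------------------------------
# networks.py (per the gist description above). Only the changed ActionHead is
# shown; `tfkl` presumably refers to tensorflow.keras.layers as imported at
# the top of networks.py.
# ---------------------------------------------------------------------------
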
class ActionHead(tools.Module):

  def __init__(
      self, size, layers, units, act=tf.nn.elu, dist='tanh_normal',
      init_std=5.0, min_std=1e-4, mean_scale=5.0, action_disc=5):
    self._size = size
    self._layers = layers
    self._units = units
    self._dist = dist
    self._act = act
    self._min_std = min_std
    self._init_std = init_std
    self._mean_scale = mean_scale
    self._action_disc = action_disc
    self._cell = tfkl.GRUCell(self._units)

  def __call__(self, features, dtype=None, horizon=0):
    x = features
    for index in range(self._layers):
      x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
    h = x  # init hidden state
    act_dist = [self.action(h, dtype)]
    if horizon == 0:
      self.hidden = h
      return act_dist[0]
    act = [act_dist[-1].sample()]
    self.hidden = [h]
    for _ in range(horizon - 1):
      h, _ = self._cell(act[-1], [h])
      act_dist.append(self.action(h, dtype))
      act.append(act_dist[-1].sample())
      self.hidden.append(h)
    return act_dist, act

  def action(self, features, dtype):
    raw_init_std = np.log(np.exp(self._init_std) - 1)
    x = features
    if self._dist == 'tanh_normal':
      # https://www.desmos.com/calculator/rcmcf5jwe7
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
      std = tf.nn.softplus(std + raw_init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'normal':
      x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
      std = tf.nn.softplus(std + raw_init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'onehot':
      x = self.get(f'hout', tfkl.Dense, self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      dist = tools.OneHotDist(x, dtype=dtype)
    elif self._dist == 'discretized':
      x = self.get(f'hout', tfkl.Dense, self._size * self._action_disc)(x)
      x = tf.reshape(x, x.shape[:-1] + [self._size, self._action_disc])
      if dtype:
        x = tf.cast(x, dtype)
      dist = tfd.Independent(tools.UnifDiscDist(x, dtype=dtype), 1)
    else:
      raise NotImplementedError(self._dist)
    return dist
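
  # Minimal usage sketch (added, illustrative only): with horizon=0 the head
  # behaves like the previous ActionHead and returns a single action
  # distribution, while horizon > 0 unrolls the GRU cell on its own samples
  # and returns (list_of_distributions, list_of_sampled_actions):
  #   head = ActionHead(num_actions, 4, 400, dist='onehot')
  #   dist = head(feat)                    # one action distribution
  #   dists, acts = head(feat, horizon=3)  # open-loop plan of length 3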