dreamer.py
import argparse
import collections
import functools
import gc
import os
import pathlib
import resource
import sys
import warnings

warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'

import numpy as np
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
# tf.get_logger().setLevel('ERROR')
from tensorflow_probability import distributions as tfd

sys.path.append(str(pathlib.Path(__file__).parent))

import models
import tools
import wrappers


def define_config():
  config = tools.AttrDict()
  # General.
  config.load_config = False
  config.ckpt = True
  config.logdir = pathlib.Path('.')
  config.ckptdir = pathlib.Path('.')
  config.traindir = None
  config.evaldir = None
  config.seed = 0
  config.steps = 1e7
  config.eval_every = 1e4
  config.log_every = 1e4
  config.backup_every = 1e3
  config.gpu_growth = True
  config.precision = 16
  config.debug = False
  config.expl_gifs = False
  config.compress_dataset = False
  config.gc = False
  config.deleps = False
  # Environment.
  config.task = 'dmc_walker_walk'
  config.size = (64, 64)
  config.parallel = 'none'
  config.envs = 1
  config.action_repeat = 2
  config.time_limit = 1000
  config.prefill = 5000
  config.eval_noise = 0.0
  config.clip_rewards = 'identity'
  config.atari_grayscale = False
  config.atari_lifes = False
  config.atari_fire_start = False
  config.dmlab_path = None
  # Model.
  config.dynamics = 'rssm'
  config.grad_heads = ('image', 'reward')
  config.deter_size = 200
  config.stoch_size = 50
  config.stoch_depth = 10
  config.dynamics_deep = False
  config.units = 400
  config.reward_layers = 2
  config.discount_layers = 3
  config.value_layers = 3
  config.actor_layers = 4
  config.act = 'elu'
  config.cnn_depth = 32
  config.encoder_kernels = (4, 4, 4, 4)
  config.decoder_kernels = (5, 5, 6, 6)
  config.decoder_thin = True
  config.cnn_embed_size = 0
  config.kl_scale = 1.0
  config.free_nats = 3.0
  config.kl_balance = 0.8
  config.pred_discount = False
  config.discount_rare_bonus = 0.0
  config.discount_scale = 10.0
  config.reward_scale = 2.0
  config.weight_decay = 0.0
  config.proprio = False
  # Training.
  config.batch_size = 50
  config.batch_length = 50
  config.train_every = 500
  config.train_steps = 100
  config.pretrain = 100
  config.model_lr = 6e-4
  config.value_lr = 8e-5
  config.actor_lr = 8e-5
  config.grad_clip = 100.0
  config.opt_eps = 1e-5
  config.value_grad_clip = 100.0
  config.actor_grad_clip = 100.0
  config.dataset_size = 0
  config.dropout_embed = 0.0
  config.dropout_feat = 0.0
  config.dropout_imag = 0.0
  config.oversample_ends = False
  config.slow_value_target = False
  config.slow_actor_target = False
  config.slow_target_update = 0
  config.slow_target_soft = False
  config.slow_target_moving_weight = 0.1
  # Behavior.
  config.discount = 0.99
  config.discount_lambda = 0.95
  config.imag_horizon = 15
  config.actor_dist = 'tanh_normal'
  config.actor_disc = 5
  config.actor_entropy = 0.0
  config.actor_init_std = 5.0
  config.expl = 'additive_gaussian'
  config.expl_amount = 0.3
  config.eval_state_mean = False
  config.reuse_task_behavior = False
  config.task_expl_reward_weights = 0.5
  # Exploration.
  config.expl_behavior = 'greedy'
  config.expl_until = 0
  config.disag_target = 'stoch'  # embed
  config.disag_scale = 5.0
  config.disag_log = True
  config.disag_samples = 10
  config.disag_models = 10
  config.disag_offset = 1
  config.disag_layers = 4
  config.vim_layers = 4
  config.disag_units = 400
  config.disag_project = 0  # 100
  config.disag_noise_scale = 10.0
  config.disag_save_memory = True
  config.empow_scale = 1.0
  config.empow_entropy = 'auto'
  config.empow_marginal = 'sequences'
  config.empow_sequences = 5  # 3
  config.empow_samples = 5  # 3
  config.temperature = 1.0
  config.empow_horizon = 1
  config.reuse_actor = False
  return config


class Dreamer(tools.Module):

  def __init__(self, config, logger, dataset):
    self._config = config
    self._logger = logger
    self._float = prec.global_policy().compute_dtype
    self._should_log = tools.Every(config.log_every)
    self._should_train = tools.Every(config.train_every)
    self._should_pretrain = tools.Once()
    self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
    self._metrics = collections.defaultdict(tf.metrics.Mean)
    with tf.device('cpu:0'):
      self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
    # self._strategy = tf.distribute.MirroredStrategy()
    # with self._strategy.scope():
    #   self._dataset = self._strategy.experimental_distribute_dataset(dataset)
    self._dataset = iter(dataset)
    self._world_model = models.WorldModel(config)
    self._behavior = models.ImagBehavior(config, self._world_model)
    if not config.reuse_task_behavior:
      behavior = models.ImagBehavior(config, self._world_model)
    else:
      behavior = self._behavior
    self._expl_behavior = dict(
        greedy=lambda: self._behavior,
        random=lambda: models.RandomBehavior(config),
        disag=lambda: models.DisagBehavior(config, self._world_model, behavior),
        disag_info=lambda: models.DisagInfoBehavior(config, self._world_model, behavior),
        disag_noise=lambda: models.DisagNoiseBehavior(config, self._world_model, behavior),
        empow_open=lambda: models.EmpowOpenBehavior(config, self._world_model, behavior),
        empow_action=lambda: models.EmpowActionBehavior(config, self._world_model, behavior),
        empow_state=lambda: models.EmpowStateBehavior(config, self._world_model, behavior),
        empow_step_state=lambda: models.EmpowStepStateBehavior(config, self._world_model, behavior),
        empow_step_action=lambda: models.EmpowStepActionBehavior(config, self._world_model, behavior),
        empow_vim=lambda: models.EmpowVIMBehavior(config, self._world_model, behavior),
    )[config.expl_behavior]()
    # Train step to initialize variables including optimizer statistics.
    self.train(next(self._dataset))

  def __call__(self, obs, reset, state=None, training=True):
    step = self._step.numpy().item()
    if state is not None and reset.any():
      mask = tf.cast(1 - reset, self._float)[:, None]
      state = tf.nest.map_structure(lambda x: x * mask, state)
    if training and self._should_train(step):
      # with self._strategy.scope():
      steps = (
          self._config.pretrain if self._should_pretrain()
          else self._config.train_steps)
      for _ in range(steps):
        self.train(next(self._dataset))
      if self._should_log(step):
        for name, mean in self._metrics.items():
          self._logger.scalar(name, float(mean.result()))
          mean.reset_states()
        openl = self._world_model.video_pred(next(self._dataset))
        self._logger.video('train_openl', openl)
        self._logger.write(fps=True)
    action, state = self._policy(obs, state, training)
    if training:
      self._step.assign_add(len(reset))
      self._logger.step = self._config.action_repeat * self._step.numpy().item()
    return action, state

  @tf.function
  def _policy(self, obs, state, training):
    if state is None:
      batch_size = len(obs['image'])
      latent = self._world_model.dynamics.initial(batch_size)
      action = tf.zeros((batch_size, self._config.num_actions), self._float)
    else:
      latent, action = state
    embed = self._world_model.encoder(self._world_model.preprocess(obs))
    latent, _ = self._world_model.dynamics.obs_step(latent, action, embed)
    if self._config.eval_state_mean:
      latent['stoch'] = latent['mean']
    feat = self._world_model.dynamics.get_feat(latent)
    if not training:
      action = self._behavior.actor(feat).mode()
    elif self._should_expl(self._step):
      action = self._expl_behavior.actor(feat).sample()
    else:
      action = self._behavior.actor(feat).sample()
    action = self._exploration(action, training)
    state = (latent, action)
    return action, state

  def _exploration(self, action, training):
    amount = self._config.expl_amount if training else self._config.eval_noise
    if amount == 0:
      return action
    amount = tf.cast(amount, self._float)
    if self._config.expl == 'additive_gaussian':
      return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
    if self._config.expl == 'epsilon_greedy':
      probs = amount / self._config.num_actions + (1 - amount) * action
      return tools.OneHotDist(probs=probs).sample()
    raise NotImplementedError(self._config.expl)

  # @tf.function
  # def train(self, data):
  #   self._strategy.experimental_run_v2(self._train, (data,))

  @tf.function
  def train(self, data):
    metrics = {}
    embed, post, feat, mets = self._world_model.train(data)
    metrics.update(mets)
    reward = lambda f, s, a: self._world_model.heads['reward'](f).mode()
    if not self._config.reuse_task_behavior:
      metrics.update(self._behavior.train(post, reward))
    if self._config.expl_behavior != 'greedy':
      if self._config.reuse_task_behavior:
        mets = self._expl_behavior.train(post, feat, embed, reward)
      else:
        mets = self._expl_behavior.train(post, feat, embed)
      metrics.update({'expl_' + key: value for key, value in mets.items()})
    # if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
    # if self._config.pred_discount:
    #   metrics.update(self._episode_end_precision_recall(data, feat))
    for name, value in metrics.items():
      self._metrics[name].update_state(value)

  # def _episode_end_precision_recall(self, data, feat):
  #   pred = self._world_model.heads['discount'](feat).mean()
  #   ppos, pneg = pred <= 0.5, pred > 0.5
  #   dpos, dneg = data['discount'] <= 0.5, data['discount'] > 0.5
  #   tp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dpos), self._float))
  #   tn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dneg), self._float))
  #   fp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dneg), self._float))
  #   fn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dpos), self._float))
  #   tpr, tnr, ppv = tp / (tp + fn), tn / (tn + fp), tp / (tp + fp)
  #   ba, f1 = (tpr + tnr) / 2, 2 * ppv * tpr / (ppv + tpr)
  #   any_ = tf.cast(tf.reduce_any(dpos), self._float)
  #   metrics = dict(
  #       ee_tp=tp, ee_tn=tn, ee_fp=fp, ee_fn=fn, ee_tpr=tpr, ee_tnr=tnr,
  #       ee_ppv=ppv, ee_ba=ba, ee_f1=f1, ee_any=any_)
  #   return {
  #       k: tf.where(tf.math.is_nan(v), tf.zeros_like(v), v)
  #       for k, v in metrics.items()}


def count_steps(directory):
  return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in directory.glob('*.npz'))


def make_dataset(episodes, config):
  example = episodes[next(iter(episodes.keys()))]
  example = tools.decompress_episode(example)
  types = {k: v.dtype for k, v in example.items()}
  shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
  generator = lambda: tools.sample_episodes(
      episodes, config.batch_length, config.oversample_ends)
  dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  dataset = dataset.batch(config.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(10)
  return dataset


def make_env(config, logger, mode, train_eps, eval_eps):
  suite, task = config.task.split('_', 1)
  if suite == 'dmc' or suite == 'dmctoy':
    env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
    env = wrappers.NormalizeActions(env)
  elif suite == 'atari':
    env = wrappers.Atari(
        task, config.action_repeat, config.size,
        grayscale=config.atari_grayscale, sticky_actions=True,
        life_done=config.atari_lifes and (mode == 'train'))
    env = wrappers.OneHotAction(env)
  elif suite == 'dmlab':
    args = task, config.action_repeat, config.size
    env = wrappers.DeepMindLab(*args, path=config.dmlab_path)
    env = wrappers.OneHotAction(env)
  elif suite == 'mc':
    args = task, mode, config.size, config.action_repeat
    env = wrappers.Minecraft(*args)
    env = wrappers.OneHotAction(env)
  else:
    raise NotImplementedError(suite)
  env = wrappers.TimeLimit(env, config.time_limit)
  callbacks = [functools.partial(
      process_episode, config, logger, mode, train_eps, eval_eps)]
  env = wrappers.Collect(env, callbacks, config.precision)
  env = wrappers.RewardObs(env)
  return env


def process_episode(config, logger, mode, train_eps, eval_eps, episode):
  directory = dict(train=config.traindir, eval=config.evaldir)[mode]
  cache = dict(train=train_eps, eval=eval_eps)[mode]
  filename = tools.save_episodes(directory, [episode])[0]
  length = len(episode['reward']) - 1
  score = float(episode['reward'].astype(np.float32).sum())
  video = episode['image']
  if mode == 'eval':
    if config.deleps:
      for key, value in list(cache.items()):
        del cache[key]
        del value
    cache.clear()
  if mode == 'train' and config.dataset_size:
    total = 0
    for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
      if total <= config.dataset_size - length:
        total += len(ep['reward']) - 1
      else:
        del cache[key]
        if config.deleps:
          del ep
    logger.scalar('dataset_size', total + length)
  if config.compress_dataset:
    episode = tools.compress_episode(episode, (np.uint8,))
  cache[str(filename)] = episode
  print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
  logger.scalar(f'{mode}_return', score)
  logger.scalar(f'{mode}_length', length)
  logger.scalar(f'{mode}_episodes', len(cache))
  if mode == 'eval' or config.expl_gifs:
    logger.video(f'{mode}_policy', video[None])
  logger.write()


def main(config):
  if config.load_config:
    from configs import default_configs
    suite, _ = config.task.split('_', 1)
    default = default_configs[suite]
    config = vars(config)
    for key, value in default.items():
      config[key] = value
    config = tools.AttrDict(config)
    print(config)
  config.traindir = config.traindir or config.logdir / 'train_eps'
  config.evaldir = config.evaldir or config.logdir / 'eval_eps'
  config.steps //= config.action_repeat
  config.eval_every //= config.action_repeat
  config.log_every //= config.action_repeat
  config.backup_every //= config.action_repeat
  config.time_limit //= config.action_repeat
  config.act = getattr(tf.nn, config.act)
  if config.empow_entropy == 'auto':
    options = dict(
        tanh_normal='sample', onehot='sample', discretized='analytic')
    config.empow_entropy = options[config.actor_dist]
  if config.debug:
    tf.config.experimental_run_functions_eagerly(True)
  if config.gpu_growth:
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
      tf.config.experimental.set_memory_growth(gpu, True)
  assert config.precision in (16, 32), config.precision
  if config.precision == 16:
    prec.set_policy(prec.Policy('mixed_float16'))
  print('Logdir', config.logdir)
  config.logdir.mkdir(parents=True, exist_ok=True)
  step = count_steps(config.traindir)
  logger = tools.Logger(config.logdir, config.action_repeat * step)

  print('Create envs.')
  train_eps = tools.load_episodes(config.traindir, limit=config.dataset_size)
  eval_eps = tools.load_episodes(config.evaldir, limit=1)
  make1 = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
  make2 = lambda mode: wrappers.Async(lambda: make1(mode), config.parallel)
  train_envs = [make2('train') for _ in range(config.envs)]
  eval_envs = [make2('eval') for _ in range(config.envs)]
  acts = train_envs[0].action_space
  config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]

  prefill = max(0, config.prefill - count_steps(config.traindir))
  print(f'Prefill dataset ({prefill} steps).')
  random_agent = lambda o, d, s: ([acts.sample() for _ in d], s)
  tools.simulate(random_agent, train_envs, prefill)
  tools.simulate(random_agent, eval_envs, episodes=1)
  logger.step = config.action_repeat * count_steps(config.traindir)

  print('Simulate agent.')
  train_dataset = make_dataset(train_eps, config)
  eval_dataset = iter(make_dataset(eval_eps, config))
  agent = Dreamer(config, logger, train_dataset)
  if (config.logdir / 'variables.pkl').exists():
    try:
      agent.load(config.logdir / 'variables.pkl')
    except Exception:
      agent.load(config.logdir / 'variables-backup.pkl')
    agent._should_pretrain._once = False

  state = None
  should_backup = tools.Every(config.backup_every)
  while agent._step.numpy().item() < config.steps:
    print('Process statistics.')
    if config.gc:
      gc.collect()
      logger.scalar('python_objects', len(gc.get_objects()))
    with open('/proc/self/statm') as f:
      ram = int(f.read().split()[1]) * resource.getpagesize() / 1024 / 1024 / 1024
    logger.scalar('ram', ram)
    logger.write()
    print('Start evaluation.')
    video_pred = agent._world_model.video_pred(next(eval_dataset))
    logger.video('eval_openl', video_pred)
    eval_policy = functools.partial(agent, training=False)
    tools.simulate(eval_policy, eval_envs, episodes=1)
    print('Start training.')
    state = tools.simulate(agent, train_envs, config.eval_every, state=state)
    if config.ckpt:
      agent.save(config.ckptdir / 'variables.pkl')
    agent.save(config.logdir / 'variables.pkl')
    if should_backup(agent._step.numpy().item()):
      if config.ckpt:
        agent.save(config.ckptdir / 'variables-backup.pkl')
      agent.save(config.logdir / 'variables-backup.pkl')
  for env in train_envs + eval_envs:
    env.close()


if __name__ == '__main__':
  print(os.getpid())
  try:
    import colored_traceback
    colored_traceback.add_hook()
  except ImportError:
    pass
  parser = argparse.ArgumentParser()
  for key, value in define_config().items():
    parser.add_argument(f'--{key}', type=tools.args_type(value), default=value)
  main(parser.parse_args())
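

# Example invocation (a sketch only; the log directory below is hypothetical).
# Every key returned by define_config() is exposed as a --flag by the argparse
# loop above, so any config value can be overridden from the command line:
#
#   python dreamer.py --task dmc_walker_walk --logdir ./logdir/walker_walk \
#       --expl_behavior empow_vim --precision 16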
models.py (excerpt)
# Excerpt from the models module referenced by dreamer.py above. The imports
# below are assumed from usage in this excerpt.
import tensorflow as tf

import networks
import tools

class EmpowStepActionBehavior(tools.Module):  # I(S';A|S), action space

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor
    self._tau = config.temperature  # temperature
    self._plan_ss = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std,
        action_disc=config.actor_disc)  # p(a_0|s, s')
    if config.reuse_actor:
      self._source = self._behavior.actor
    else:
      self._source = networks.ActionHead(
          config.num_actions, config.actor_layers, config.units, config.act,
          config.actor_dist, config.actor_init_std,
          action_disc=config.actor_disc)
    self._offset = networks.DenseHead(
        [], config.actor_layers, config.units, config.act)
    self._plan_opt = tools.Adam(
        'plan', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._source_opt = tools.Adam(
        'source', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)  # for both source and offset

  def train(self, start, feat, embed, task_reward=None):
    metrics = {}
    metrics.update(self._behavior.train(
        start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
    metrics.update(self._train_step(start))
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    reward = self._tau * self._offset(feat).mode()
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward

  def _imagine(self, start):  # don't use this for self._behavior.train
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    def step(prev, _):
      state, _, _ = prev
      feat = self._world_model.dynamics.get_feat(state)
      if self._config.dropout_imag:
        feat = tf.nn.dropout(feat, self._config.dropout_imag)
      action = self._source(tf.stop_gradient(feat)).sample()  # source instead of actor
      state = self._world_model.dynamics.img_step(state, action)
      return state, feat, action
    range_ = tf.range(self._config.imag_horizon)
    feat = 0 * self._world_model.dynamics.get_feat(start)
    action = 0 * self._source(feat).mean()
    states, feats, actions = tools.static_scan(step, range_, (start, feat, action))
    return feats, states, actions

  def _train_step(self, start):
    feat, _, action = self._imagine(start)
    ss_feat = tf.concat([feat[:-1], feat[1:]], axis=-1)  # time-1, batch, features
    action = tf.stop_gradient(tf.cast(action, tf.float32))
    metrics = {}
    with tf.GradientTape() as tape:
      plan_pred = self._plan_ss(ss_feat, dtype=tf.float32)
      log_plan_prob = plan_pred.log_prob(action[:-1])
      loss = -tf.reduce_mean(log_plan_prob)
    norm = self._plan_opt(tape, loss, [self._plan_ss])
    metrics['plan_grad_norm'] = norm
    metrics['plan_loss'] = loss
    if self._config.precision == 16:
      metrics['plan_dynamics_loss_scale'] = self._plan_opt.loss_scale
    with tf.GradientTape() as tape:
      offset = self._offset(feat[:-1], dtype=tf.float32).mode()
      action_pred = self._source(feat[:-1], dtype=tf.float32)
      log_action_prob = action_pred.log_prob(action[:-1])
      r_sa = log_action_prob + offset
      loss = tf.reduce_mean(tf.square(
          1. / self._tau * tf.stop_gradient(log_plan_prob) - r_sa))
    norm = self._source_opt(tape, loss, [self._source, self._offset])
    metrics['source_grad_norm'] = norm
    metrics['source_loss'] = loss
    if self._config.precision == 16:
      metrics['source_loss_scale'] = self._source_opt.loss_scale
    return metrics
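
# A rough reading of EmpowStepActionBehavior (an interpretation added for
# reference, not stated in the original comments): `_plan_ss` acts as a
# variational inverse model q(a | s, s') fit by maximum likelihood on imagined
# transitions, while `_source` pi(a | s) and the scalar `_offset(s)` are
# regressed so that log pi(a | s) + offset(s) ~ (1 / tau) * log q(a | s, s').
# Under that fit, offset(s) tracks a temperature-scaled estimate of the
# one-step bound E[log q(a | s, s') - log pi(a | s)] <= I(S'; A | S), and
# `_intrinsic_reward` exposes tau * offset(s) as the exploration reward.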

class EmpowVIMBehavior(tools.Module):  # I(S^T;A|S), action space, n-step

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor
    self._tau = config.temperature  # temperature
    self._plan = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std,
        action_disc=config.actor_disc)  # p(a_1:t|s, s')
    if config.reuse_actor:
      self._source = self._behavior.actor
    else:
      self._source = networks.ActionHead(
          config.num_actions, config.actor_layers, config.units, config.act,
          config.actor_dist, config.actor_init_std,
          action_disc=config.actor_disc)  # p(a_1:t|s)
    self._offset = networks.DenseHead(
        [], config.actor_layers, config.units, config.act)
    self._plan_opt = tools.Adam(
        'plan', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._source_opt = tools.Adam(
        'source', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)  # for both source and offset

  def train(self, start, feat, embed, task_reward=None):
    metrics = {}
    metrics.update(self._behavior.train(
        start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
    metrics.update(self._train_vim(start))
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    reward = self._tau * self._offset(feat).mode()
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward

  def _imagine(self, start, actions):
    def step(prev, action):
      state, _ = prev
      feat = self._world_model.dynamics.get_feat(state)
      if self._config.dropout_imag:
        feat = tf.nn.dropout(feat, self._config.dropout_imag)
      action = tf.cast(action, feat.dtype)
      state = self._world_model.dynamics.img_step(state, action)
      return state, feat
    feat = self._world_model.dynamics.get_feat(start)
    states, feats = tools.static_scan(step, actions, (start, feat))
    return feats[-1]

  def _train_vim(self, start):
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    feat = self._world_model.dynamics.get_feat(start)
    metrics = {}
    # Get the plan.
    action_pred, action_list = self._source(
        feat, dtype=tf.float32, horizon=self._config.empow_horizon)
    # Execute it in the model and retrieve the last state.
    last_feat = self._imagine(start, tf.stack(action_list, axis=0))
    with tf.GradientTape() as tape:
      plan_pred, _ = self._plan(
          tf.concat([feat, last_feat], axis=-1), dtype=tf.float32,
          horizon=self._config.empow_horizon)
      log_plan_prob = [
          pp.log_prob(tf.stop_gradient(a)) for a, pp in zip(action_list, plan_pred)]
      log_plan_prob = tf.reduce_mean(tf.stack(log_plan_prob, axis=0), axis=0)
      loss = -tf.reduce_mean(log_plan_prob)
    norm = self._plan_opt(tape, loss, [self._plan])
    metrics['plan_grad_norm'] = norm
    metrics['plan_loss'] = loss
    if self._config.precision == 16:
      metrics['plan_loss_scale'] = self._plan_opt.loss_scale
    with tf.GradientTape() as tape:
      offset = self._offset(feat, dtype=tf.float32).mode()
      log_action_prob = [
          ap.log_prob(tf.stop_gradient(a)) for a, ap in zip(action_list, action_pred)]
      log_action_prob = tf.reduce_mean(tf.stack(log_action_prob, axis=0), axis=0)
      r_sa = log_action_prob + offset
      loss = tf.reduce_mean(tf.square(
          1. / self._tau * tf.stop_gradient(log_plan_prob) - r_sa))
    norm = self._source_opt(tape, loss, [self._source, self._offset])
    metrics['source_grad_norm'] = norm
    metrics['source_loss'] = loss
    if self._config.precision == 16:
      metrics['source_loss_scale'] = self._source_opt.loss_scale
    return metrics
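
# EmpowVIMBehavior appears to be the n-step analogue (again an interpretation):
# `_source` proposes an open-loop sequence of `empow_horizon` actions via the
# ActionHead GRU roll-out (see networks.py below), `_imagine` pushes that
# sequence through the dynamics and returns the last collected features, and
# `_plan` scores the same action sequence given the start and final features.
# As before, the learned per-state offset is what `_intrinsic_reward` returns,
# scaled by tau, as the exploration signal.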
networks.py (excerpt)
# ActionHead as referenced by the behaviors above (networks.ActionHead).
# The imports below are assumed from usage in this excerpt.
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd

import tools

class ActionHead(tools.Module):

  def __init__(
      self, size, layers, units, act=tf.nn.elu, dist='tanh_normal',
      init_std=5.0, min_std=1e-4, mean_scale=5.0, action_disc=5):
    self._size = size
    self._layers = layers
    self._units = units
    self._dist = dist
    self._act = act
    self._min_std = min_std
    self._init_std = init_std
    self._mean_scale = mean_scale
    self._action_disc = action_disc
    self._cell = tfkl.GRUCell(self._units)

  def __call__(self, features, dtype=None, horizon=0):
    x = features
    for index in range(self._layers):
      x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
    h = x  # init hidden state
    act_dist = [self.action(h, dtype)]
    if horizon == 0:
      self.hidden = h
      return act_dist[0]
    act = [act_dist[-1].sample()]
    self.hidden = [h]
    for _ in range(horizon - 1):
      h, _ = self._cell(act[-1], [h])
      act_dist.append(self.action(h, dtype))
      act.append(act_dist[-1].sample())
      self.hidden.append(h)
    return act_dist, act

  def action(self, features, dtype):
    raw_init_std = np.log(np.exp(self._init_std) - 1)
    x = features
    if self._dist == 'tanh_normal':
      # https://www.desmos.com/calculator/rcmcf5jwe7
      x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
      std = tf.nn.softplus(std + raw_init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
      dist = tfd.Independent(dist, 1)
      dist = tools.SampleDist(dist)
    elif self._dist == 'normal':
      x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      mean, std = tf.split(x, 2, -1)
      mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
      std = tf.nn.softplus(std + raw_init_std) + self._min_std
      dist = tfd.Normal(mean, std)
      dist = tfd.Independent(dist, 1)
    elif self._dist == 'onehot':
      x = self.get('hout', tfkl.Dense, self._size)(x)
      if dtype:
        x = tf.cast(x, dtype)
      dist = tools.OneHotDist(x, dtype=dtype)
    elif self._dist == 'discretized':
      x = self.get('hout', tfkl.Dense, self._size * self._action_disc)(x)
      x = tf.reshape(x, x.shape[:-1] + [self._size, self._action_disc])
      if dtype:
        x = tf.cast(x, dtype)
      dist = tfd.Independent(tools.UnifDiscDist(x, dtype=dtype), 1)
    else:
      raise NotImplementedError(self._dist)
    return dist
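
# Usage sketch (hypothetical sizes, mirroring how the behaviors above call
# this head):
#
#   head = ActionHead(size=6, layers=4, units=400)  # e.g. a 6-dim action space
#   dist = head(feat)                    # horizon=0: a single action distribution
#   dists, acts = head(feat, horizon=3)  # open-loop: lists of length 3, where
#                                        # later steps are produced by the GRU cell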