import argparse
import collections
import functools
import gc
import os
import pathlib
import resource
import sys
import warnings

warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'

import numpy as np
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
# tf.get_logger().setLevel('ERROR')
from tensorflow_probability import distributions as tfd

sys.path.append(str(pathlib.Path(__file__).parent))
import models
import tools
import wrappers


def define_config():
  config = tools.AttrDict()
  # General.
  config.load_config = False
  config.ckpt = True
  config.logdir = pathlib.Path('.')
  config.ckptdir = pathlib.Path('.')
  config.traindir = None
  config.evaldir = None
  config.seed = 0
  config.steps = 1e7
  config.eval_every = 1e4
  config.log_every = 1e4
  config.backup_every = 1e3
  config.gpu_growth = True
  config.precision = 16
  config.debug = False
  config.expl_gifs = False
  config.compress_dataset = False
  config.gc = False
  config.deleps = False
  # Environment.
  config.task = 'dmc_walker_walk'
  config.size = (64, 64)
  config.parallel = 'none'
  config.envs = 1
  config.action_repeat = 2
  config.time_limit = 1000
  config.prefill = 5000
  config.eval_noise = 0.0
  config.clip_rewards = 'identity'
  config.atari_grayscale = False
  config.atari_lifes = False
  config.atari_fire_start = False
  config.dmlab_path = None
  # Model.
  config.dynamics = 'rssm'
  config.grad_heads = ('image', 'reward')
  config.deter_size = 200
  config.stoch_size = 50
  config.stoch_depth = 10
  config.dynamics_deep = False
  config.units = 400
  config.reward_layers = 2
  config.discount_layers = 3
  config.value_layers = 3
  config.actor_layers = 4
  config.act = 'elu'
  config.cnn_depth = 32
  config.encoder_kernels = (4, 4, 4, 4)
  config.decoder_kernels = (5, 5, 6, 6)
  config.decoder_thin = True
  config.cnn_embed_size = 0
  config.kl_scale = 1.0
  config.free_nats = 3.0
  config.kl_balance = 0.8
  config.pred_discount = False
  config.discount_rare_bonus = 0.0
  config.discount_scale = 10.0
  config.reward_scale = 2.0
  config.weight_decay = 0.0
  config.proprio = False
  # Training.
  config.batch_size = 50
  config.batch_length = 50
  config.train_every = 500
  config.train_steps = 100
  config.pretrain = 100
  config.model_lr = 6e-4
  config.value_lr = 8e-5
  config.actor_lr = 8e-5
  config.grad_clip = 100.0
  config.opt_eps = 1e-5
  config.value_grad_clip = 100.0
  config.actor_grad_clip = 100.0
  config.dataset_size = 0
  config.dropout_embed = 0.0
  config.dropout_feat = 0.0
  config.dropout_imag = 0.0
  config.oversample_ends = False
  config.slow_value_target = False
  config.slow_actor_target = False
  config.slow_target_update = 0
  config.slow_target_soft = False
  config.slow_target_moving_weight = 0.1
  # Behavior.
  config.discount = 0.99
  config.discount_lambda = 0.95
  config.imag_horizon = 15
  config.actor_dist = 'tanh_normal'
  config.actor_disc = 5
  config.actor_entropy = 0.0
  config.actor_init_std = 5.0
  config.expl = 'additive_gaussian'
  config.expl_amount = 0.3
  config.eval_state_mean = False
  config.reuse_task_behavior = False
  config.task_expl_reward_weights = 0.5
  # Exploration.
  config.expl_behavior = 'greedy'
  config.expl_until = 0
  config.disag_target = 'stoch'  # embed
  config.disag_scale = 5.0
  config.disag_log = True
  config.disag_samples = 10
  config.disag_models = 10
  config.disag_offset = 1
  config.disag_layers = 4
  config.vim_layers = 4
  config.disag_units = 400
  config.disag_project = 0  # 100
  config.disag_noise_scale = 10.0
  config.disag_save_memory = True
  config.empow_scale = 1.0
  config.empow_entropy = 'auto'
  config.empow_marginal = 'sequences'
  config.empow_sequences = 5  # 3
  config.empow_samples = 5  # 3
  config.temperature = 1.0
  config.vim_horizon = 5
  return config
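

# The Dreamer agent ties everything together: it owns the world model, the
# imagination behavior trained on the task reward, and an optional separate
# exploration behavior selected by `expl_behavior`. Acting alternates with
# training: every `train_every` policy steps it runs `train_steps` gradient
# updates on replayed batches and periodically logs metrics and open-loop
# video predictions.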
class Dreamer(tools.Module):

  def __init__(self, config, logger, dataset):
    self._config = config
    self._logger = logger
    self._float = prec.global_policy().compute_dtype
    self._should_log = tools.Every(config.log_every)
    self._should_train = tools.Every(config.train_every)
    self._should_pretrain = tools.Once()
    self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
    self._metrics = collections.defaultdict(tf.metrics.Mean)
    with tf.device('cpu:0'):
      self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
    # self._strategy = tf.distribute.MirroredStrategy()
    # with self._strategy.scope():
    #   self._dataset = self._strategy.experimental_distribute_dataset(dataset)
    self._dataset = iter(dataset)
    self._world_model = models.WorldModel(config)
    self._behavior = models.ImagBehavior(config, self._world_model)
    if not config.reuse_task_behavior:
      behavior = models.ImagBehavior(config, self._world_model)
    else:
      behavior = self._behavior
    self._expl_behavior = dict(
        greedy=lambda: self._behavior,
        random=lambda: models.RandomBehavior(config),
        disag=lambda: models.DisagBehavior(config, self._world_model, behavior),
        disag_info=lambda: models.DisagInfoBehavior(config, self._world_model, behavior),
        disag_noise=lambda: models.DisagNoiseBehavior(config, self._world_model, behavior),
        empow_open=lambda: models.EmpowOpenBehavior(config, self._world_model, behavior),
        empow_action=lambda: models.EmpowActionBehavior(config, self._world_model, behavior),
        empow_state=lambda: models.EmpowStateBehavior(config, self._world_model, behavior),
        empow_step=lambda: models.EmpowStepBehavior(config, self._world_model, behavior),
        empow_vim=lambda: models.EmpowVIMBehavior(config, self._world_model, behavior),
    )[config.expl_behavior]()
    # Train step to initialize variables including optimizer statistics.
    self.train(next(self._dataset))

  def __call__(self, obs, reset, state=None, training=True):
    step = self._step.numpy().item()
    if state is not None and reset.any():
      mask = tf.cast(1 - reset, self._float)[:, None]
      state = tf.nest.map_structure(lambda x: x * mask, state)
    if training and self._should_train(step):
      # with self._strategy.scope():
      steps = (
          self._config.pretrain if self._should_pretrain()
          else self._config.train_steps)
      for _ in range(steps):
        self.train(next(self._dataset))
      if self._should_log(step):
        for name, mean in self._metrics.items():
          self._logger.scalar(name, float(mean.result()))
          mean.reset_states()
        openl = self._world_model.video_pred(next(self._dataset))
        self._logger.video('train_openl', openl)
        self._logger.write(fps=True)
    action, state = self._policy(obs, state, training)
    if training:
      self._step.assign_add(len(reset))
      self._logger.step = self._config.action_repeat * self._step.numpy().item()
    return action, state

  @tf.function
  def _policy(self, obs, state, training):
    if state is None:
      batch_size = len(obs['image'])
      latent = self._world_model.dynamics.initial(len(obs['image']))
      action = tf.zeros((batch_size, self._config.num_actions), self._float)
    else:
      latent, action = state
    embed = self._world_model.encoder(self._world_model.preprocess(obs))
    latent, _ = self._world_model.dynamics.obs_step(latent, action, embed)
    if self._config.eval_state_mean:
      latent['stoch'] = latent['mean']
    feat = self._world_model.dynamics.get_feat(latent)
    if not training:
      action = self._behavior.actor(feat).mode()
    elif self._should_expl(self._step):
      action = self._expl_behavior.actor(feat).sample()
    else:
      action = self._behavior.actor(feat).sample()
    action = self._exploration(action, training)
    state = (latent, action)
    return action, state

  def _exploration(self, action, training):
    amount = self._config.expl_amount if training else self._config.eval_noise
    if amount == 0:
      return action
    amount = tf.cast(amount, self._float)
    if self._config.expl == 'additive_gaussian':
      return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
    if self._config.expl == 'epsilon_greedy':
      probs = amount / self._config.num_actions + (1 - amount) * action
      return tools.OneHotDist(probs=probs).sample()
    raise NotImplementedError(self._config.expl)

  # @tf.function
  # def train(self, data):
  #   self._strategy.experimental_run_v2(self._train, (data,))

  @tf.function
  def train(self, data):
    metrics = {}
    embed, post, feat, mets = self._world_model.train(data)
    metrics.update(mets)
    reward = lambda f, s, a: self._world_model.heads['reward'](f).mode()
    if not self._config.reuse_task_behavior:
      metrics.update(self._behavior.train(post, reward))
    if self._config.expl_behavior != 'greedy':
      if self._config.reuse_task_behavior:
        mets = self._expl_behavior.train(post, feat, embed, reward)
      else:
        mets = self._expl_behavior.train(post, feat, embed)
      metrics.update({'expl_' + key: value for key, value in mets.items()})
    # if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
    # if self._config.pred_discount:
    #   metrics.update(self._episode_end_precision_recall(data, feat))
    for name, value in metrics.items():
      self._metrics[name].update_state(value)

  # def _episode_end_precision_recall(self, data, feat):
  #   pred = self._world_model.heads['discount'](feat).mean()
  #   ppos, pneg = pred <= 0.5, pred > 0.5
  #   dpos, dneg = data['discount'] <= 0.5, data['discount'] > 0.5
  #   tp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dpos), self._float))
  #   tn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dneg), self._float))
  #   fp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dneg), self._float))
  #   fn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dpos), self._float))
  #   tpr, tnr, ppv = tp / (tp + fn), tn / (tn + fp), tp / (tp + fp)
  #   ba, f1 = (tpr + tnr) / 2, 2 * ppv * tpr / (ppv + tpr)
  #   any_ = tf.cast(tf.reduce_any(dpos), self._float)
  #   metrics = dict(
  #       ee_tp=tp, ee_tn=tn, ee_fp=fp, ee_fn=fn, ee_tpr=tpr, ee_tnr=tnr,
  #       ee_ppv=ppv, ee_ba=ba, ee_f1=f1, ee_any=any_)
  #   return {
  #       k: tf.where(tf.math.is_nan(v), tf.zeros_like(v), v)
  #       for k, v in metrics.items()}
def count_steps(directory):
  return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in directory.glob('*.npz'))


def make_dataset(episodes, config):
  example = episodes[next(iter(episodes.keys()))]
  example = tools.decompress_episode(example)
  types = {k: v.dtype for k, v in example.items()}
  shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
  generator = lambda: tools.sample_episodes(
      episodes, config.batch_length, config.oversample_ends)
  dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  dataset = dataset.batch(config.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(10)
  return dataset


def make_env(config, logger, mode, train_eps, eval_eps):
  suite, task = config.task.split('_', 1)
  if suite == 'dmc' or suite == 'dmctoy':
    env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
    env = wrappers.NormalizeActions(env)
  elif suite == 'atari':
    env = wrappers.Atari(
        task, config.action_repeat, config.size,
        grayscale=config.atari_grayscale, sticky_actions=True,
        life_done=config.atari_lifes and (mode == 'train'))
    env = wrappers.OneHotAction(env)
  elif suite == 'dmlab':
    args = task, config.action_repeat, config.size
    env = wrappers.DeepMindLab(*args, path=config.dmlab_path)
    env = wrappers.OneHotAction(env)
  elif suite == 'mc':
    args = task, mode, config.size, config.action_repeat
    env = wrappers.Minecraft(*args)
    env = wrappers.OneHotAction(env)
  else:
    raise NotImplementedError(suite)
  env = wrappers.TimeLimit(env, config.time_limit)
  callbacks = [functools.partial(
      process_episode, config, logger, mode, train_eps, eval_eps)]
  env = wrappers.Collect(env, callbacks, config.precision)
  env = wrappers.RewardObs(env)
  return env


def process_episode(config, logger, mode, train_eps, eval_eps, episode):
  directory = dict(train=config.traindir, eval=config.evaldir)[mode]
  cache = dict(train=train_eps, eval=eval_eps)[mode]
  filename = tools.save_episodes(directory, [episode])[0]
  length = len(episode['reward']) - 1
  score = float(episode['reward'].astype(np.float32).sum())
  video = episode['image']
  if mode == 'eval':
    if config.deleps:
      for key, value in list(cache.items()):
        del cache[key]
        del value
    cache.clear()
  if mode == 'train' and config.dataset_size:
    total = 0
    for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
      if total <= config.dataset_size - length:
        total += len(ep['reward']) - 1
      else:
        del cache[key]
        if config.deleps:
          del ep
    logger.scalar('dataset_size', total + length)
  if config.compress_dataset:
    episode = tools.compress_episode(episode, (np.uint8,))
  cache[str(filename)] = episode
  print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
  logger.scalar(f'{mode}_return', score)
  logger.scalar(f'{mode}_length', length)
  logger.scalar(f'{mode}_episodes', len(cache))
  if mode == 'eval' or config.expl_gifs:
    logger.video(f'{mode}_policy', video[None])
  logger.write()
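

# Entry point: builds the config, creates train/eval environments, prefills
# the replay buffer with a random policy, then alternates evaluation episodes
# with `eval_every` steps of training, checkpointing after every round.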
def main(config):
  if config.load_config:
    from configs import default_configs
    suite, _ = config.task.split('_', 1)
    default = default_configs[suite]
    config = vars(config)
    for key, value in default.items():
      config[key] = value
    config = tools.AttrDict(config)
    print(config)
  config.traindir = config.traindir or config.logdir / 'train_eps'
  config.evaldir = config.evaldir or config.logdir / 'eval_eps'
  config.steps //= config.action_repeat
  config.eval_every //= config.action_repeat
  config.log_every //= config.action_repeat
  config.backup_every //= config.action_repeat
  config.time_limit //= config.action_repeat
  config.act = getattr(tf.nn, config.act)
  if config.empow_entropy == 'auto':
    options = dict(
        tanh_normal='sample', onehot='sample', discretized='analytic')
    config.empow_entropy = options[config.actor_dist]
  if config.debug:
    tf.config.experimental_run_functions_eagerly(True)
  if config.gpu_growth:
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
      tf.config.experimental.set_memory_growth(gpu, True)
  assert config.precision in (16, 32), config.precision
  if config.precision == 16:
    prec.set_policy(prec.Policy('mixed_float16'))
  print('Logdir', config.logdir)
  config.logdir.mkdir(parents=True, exist_ok=True)
  step = count_steps(config.traindir)
  logger = tools.Logger(config.logdir, config.action_repeat * step)

  print('Create envs.')
  train_eps = tools.load_episodes(config.traindir, limit=config.dataset_size)
  eval_eps = tools.load_episodes(config.evaldir, limit=1)
  make1 = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
  make2 = lambda mode: wrappers.Async(lambda: make1(mode), config.parallel)
  train_envs = [make2('train') for _ in range(config.envs)]
  eval_envs = [make2('eval') for _ in range(config.envs)]
  acts = train_envs[0].action_space
  config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]

  prefill = max(0, config.prefill - count_steps(config.traindir))
  print(f'Prefill dataset ({prefill} steps).')
  random_agent = lambda o, d, s: ([acts.sample() for _ in d], s)
  tools.simulate(random_agent, train_envs, prefill)
  tools.simulate(random_agent, eval_envs, episodes=1)
  logger.step = config.action_repeat * count_steps(config.traindir)

  print('Simulate agent.')
  train_dataset = make_dataset(train_eps, config)
  eval_dataset = iter(make_dataset(eval_eps, config))
  agent = Dreamer(config, logger, train_dataset)
  if (config.logdir / 'variables.pkl').exists():
    try:
      agent.load(config.logdir / 'variables.pkl')
    except:
      agent.load(config.logdir / 'variables-backup.pkl')
    agent._should_pretrain._once = False

  state = None
  should_backup = tools.Every(config.backup_every)
  while agent._step.numpy().item() < config.steps:
    print('Process statistics.')
    if config.gc:
      gc.collect()
      logger.scalar('python_objects', len(gc.get_objects()))
    with open('/proc/self/statm') as f:
      ram = int(f.read().split()[1]) * resource.getpagesize() / 1024 / 1024 / 1024
    logger.scalar('ram', ram)
    logger.write()
    print('Start evaluation.')
    video_pred = agent._world_model.video_pred(next(eval_dataset))
    logger.video('eval_openl', video_pred)
    eval_policy = functools.partial(agent, training=False)
    tools.simulate(eval_policy, eval_envs, episodes=1)
    print('Start training.')
    state = tools.simulate(agent, train_envs, config.eval_every, state=state)
    if config.ckpt:
      agent.save(config.ckptdir / 'variables.pkl')
    agent.save(config.logdir / 'variables.pkl')
    if should_backup(agent._step.numpy().item()):
      if config.ckpt:
        agent.save(config.ckptdir / 'variables-backup.pkl')
      agent.save(config.logdir / 'variables-backup.pkl')
  for env in train_envs + eval_envs:
    env.close()


if __name__ == '__main__':
  print(os.getpid())
  try:
    import colored_traceback
    colored_traceback.add_hook()
  except ImportError:
    pass
  parser = argparse.ArgumentParser()
  for key, value in define_config().items():
    parser.add_argument(f'--{key}', type=tools.args_type(value), default=value)
  main(parser.parse_args())
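# ============================================================================
# Second file of the gist: world model, imagination actor-critic, and
# exploration behaviors (imported by the script above as `models`).
# ============================================================================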
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd

import networks
import tools
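

# World model: convolutional encoder, an RSSM (or discrete RSSM) latent
# dynamics model, and decoder heads for image, reward, and optionally the
# discount (episode continuation). Trained jointly by maximizing the decoder
# log-likelihoods minus the scaled KL between posterior and prior.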
class WorldModel(tools.Module):

  def __init__(self, config):
    self._config = config
    self.encoder = networks.ConvEncoder(
        config.cnn_depth, config.act, config.encoder_kernels,
        config.cnn_embed_size or None, config.proprio)
    if config.dynamics == 'rssm':
      self.dynamics = networks.RSSM(
          config.stoch_size, config.deter_size, config.deter_size,
          deep=config.dynamics_deep)
    elif config.dynamics == 'discrete_rssm':
      self.dynamics = networks.DiscreteRSSM(
          config.stoch_size, config.stoch_depth,
          config.deter_size, config.deter_size)
    else:
      raise NotImplementedError(config.dynamics)
    self.heads = {}
    shape = config.size + (1 if config.atari_grayscale else 3,)
    self.heads['image'] = networks.ConvDecoder(
        config.cnn_depth, config.act, shape, config.decoder_kernels,
        config.decoder_thin)
    self.heads['reward'] = networks.DenseHead(
        [], config.reward_layers, config.units, config.act)
    if config.pred_discount:
      self.heads['discount'] = networks.DenseHead(
          [], config.discount_layers, config.units, config.act, dist='binary')
    for name in config.grad_heads:
      assert name in self.heads, name
    self._model_opt = tools.Adam(
        'model', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._scales = dict(
        reward=config.reward_scale, discount=config.discount_scale)
    self._float = prec.global_policy().compute_dtype

  def train(self, data):
    data = self.preprocess(data)
    with tf.GradientTape() as model_tape:
      embed = self.encoder(data)
      if self._config.dropout_embed:
        embed = tf.nn.dropout(embed, self._config.dropout_embed)
      post, prior = self.dynamics.observe(embed, data['action'])
      kl_loss, kl_value = self.dynamics.kl_loss(
          post, prior, self._config.kl_balance, self._config.free_nats,
          self._config.kl_scale, tf.float32)
      feat = self.dynamics.get_feat(post)
      if self._config.dropout_feat:
        feat = tf.nn.dropout(feat, self._config.dropout_feat)
      likes = {}
      for name, head in self.heads.items():
        grad_head = (name in self._config.grad_heads)
        inp = feat if grad_head else tf.stop_gradient(feat)
        pred = head(inp)  # head(inp, tf.float32)
        like = pred.log_prob(data[name])
        if name == 'discount' and self._config.discount_rare_bonus:
          rare = tf.cast(data['discount'] < 0.5, self._float)
          like *= 1 + self._config.discount_rare_bonus * rare
        likes[name] = tf.reduce_mean(like) * self._scales.get(name, 1.0)
      model_loss = kl_loss - sum(likes.values())
    model_parts = [self.encoder, self.dynamics] + list(self.heads.values())
    model_norm = self._model_opt(model_tape, model_loss, model_parts)
    metrics = dict(
        **{f'{name}_loss': -like for name, like in likes.items()},
        kl=kl_value, model_loss=model_loss, model_grad_norm=model_norm,
        prior_ent=self.dynamics.get_dist(prior).entropy(),
        post_ent=self.dynamics.get_dist(post).entropy(),
    )
    if self._config.precision == 16:
      metrics['model_loss_scale'] = self._model_opt.loss_scale
    return embed, post, feat, metrics

  @tf.function
  def preprocess(self, obs):
    dtype = prec.global_policy().compute_dtype
    obs = obs.copy()
    obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5
    obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward'])
    if 'discount' in obs:
      obs['discount'] *= self._config.discount
    return obs

  @tf.function
  def video_pred(self, data):
    data = self.preprocess(data)
    truth = data['image'][:6] + 0.5
    embed = self.encoder(data)
    states, _ = self.dynamics.observe(embed[:6, :5], data['action'][:6, :5])
    recon = self.heads['image'](self.dynamics.get_feat(states)).mode()[:6]
    init = {k: v[:, -1] for k, v in states.items()}
    prior = self.dynamics.imagine(data['action'][:6, 5:], init)
    openl = self.heads['image'](self.dynamics.get_feat(prior)).mode()
    model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
    error = (model - truth + 1) / 2
    return tf.concat([truth, model, error], 2)

  # @tf.function
  # def video_pred_strip(self, data):
  #   data = self.preprocess(data)
  #   truth = data['image'][:6] + 0.5
  #   embed = self.encoder(data)
  #   states, _ = self.dynamics.observe(embed[:6, :5], data['action'][:6, :5])
  #   recon = self.heads['image'](self.dynamics.get_feat(states)).mode()[:6]
  #   init = {k: v[:, -1] for k, v in states.items()}
  #   prior = self.dynamics.imagine(data['action'][:6, 5:], init)
  #   openl = self.heads['image'](self.dynamics.get_feat(prior)).mode()
  #   model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
  #   B, T, H, W, C = list(truth.shape)
  #   truth = tf.reshape(tf.transpose(truth, (0, 2, 1, 3, 4)), (B * H, T * W, C))
  #   model = tf.reshape(tf.transpose(model, (0, 2, 1, 3, 4)), (B * H, T * W, C))
  #   return truth, model
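

# Actor-critic trained purely in imagination: latent trajectories are rolled
# out with the actor for `imag_horizon` steps, lambda-returns are computed
# from the value (or slow target value) head, and the actor maximizes the
# weighted returns plus an optional entropy bonus.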
class ImagBehavior(tools.Module):

  def __init__(self, config, world_model):
    self._config = config
    self._world_model = world_model
    self.actor = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_disc)
    self.value = networks.DenseHead(
        [], config.value_layers, config.units, config.act)
    if config.slow_value_target or config.slow_actor_target:
      self._slow_value = networks.DenseHead(
          [], config.value_layers, config.units, config.act)
      self._updates = tf.Variable(0, dtype=tf.int64)
    kw = dict(wd=config.weight_decay)
    self._actor_opt = tools.Adam(
        'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip, **kw)
    self._value_opt = tools.Adam(
        'value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw)
    if config.reuse_task_behavior:
      self.expl_value = networks.DenseHead(
          [], config.value_layers, config.units, config.act)
      if config.slow_value_target or config.slow_actor_target:
        self._slow_expl_value = networks.DenseHead(
            [], config.value_layers, config.units, config.act)
      self._expl_value_opt = tools.Adam(
          'expl_value', config.value_lr, config.opt_eps,
          config.value_grad_clip, **kw)

  def train(self, start, objective=None, imagine=None, repeats=None):
    assert bool(objective) != bool(imagine)
    if self._config.slow_value_target or self._config.slow_actor_target:
      if self._updates % self._config.slow_target_update == 0:
        if not self._config.slow_target_soft:
          for s, d in zip(self.value.variables, self._slow_value.variables):
            d.assign(s)
          if self._config.reuse_task_behavior:
            for s, d in zip(self.expl_value.variables, self._slow_expl_value.variables):
              d.assign(s)
        else:
          alpha = self._config.slow_target_moving_weight
          for s, d in zip(self.value.variables, self._slow_value.variables):
            d.assign(alpha * s + (1 - alpha) * d)
          if self._config.reuse_task_behavior:
            for s, d in zip(self.expl_value.variables, self._slow_expl_value.variables):
              d.assign(alpha * s + (1 - alpha) * d)
      self._updates.assign_add(1)
    if self._config.pred_discount:  # Last step could be terminal.
      start = {k: v[:, :-1] for k, v in start.items()}
    with tf.GradientTape() as actor_tape:
      if objective:
        imag_feat, imag_state, imag_action = self._imagine(start, repeats)
        reward = objective(imag_feat, imag_state, imag_action)
      else:
        imag_feat, reward = imagine(start)
      target, weights = self._compute_target(
          imag_feat, reward, self._config.slow_actor_target)
      actor_ent = self.actor(
          tf.stop_gradient(imag_feat), tf.float32).entropy()
      if self._config.actor_entropy:
        bonus = self._config.actor_entropy * actor_ent[:-1]
      else:
        bonus = 0.0
      if self._config.reuse_task_behavior:
        t = (
            (1 - self._config.task_expl_reward_weights) * target['task'] +
            self._config.task_expl_reward_weights * target['expl'])
        actor_loss = -tf.reduce_mean(weights * (t + bonus))
      else:
        actor_loss = -tf.reduce_mean(weights * (target + bonus))
    # Value loss.
    if self._config.slow_value_target != self._config.slow_actor_target:
      target, weights = self._compute_target(
          imag_feat, reward, self._config.slow_value_target)
    if self._config.reuse_task_behavior:
      target, expl_target = target['task'], target['expl']
    with tf.GradientTape() as value_tape:
      like = self.value(imag_feat, tf.float32)[:-1].log_prob(
          tf.stop_gradient(target))
      value_loss = -tf.reduce_mean(weights * like)
    actor_norm = self._actor_opt(actor_tape, actor_loss, [self.actor])
    value_norm = self._value_opt(value_tape, value_loss, [self.value])
    metrics = dict(
        actor_loss=actor_loss, actor_grad_norm=actor_norm,
        value_loss=value_loss, value_grad_norm=value_norm,
        actor_ent=actor_ent)
    if self._config.precision == 16:
      metrics['actor_loss_scale'] = self._actor_opt.loss_scale
      metrics['value_loss_scale'] = self._value_opt.loss_scale
    if self._config.reuse_task_behavior:
      with tf.GradientTape() as expl_value_tape:
        like = self.expl_value(imag_feat, tf.float32)[:-1].log_prob(
            tf.stop_gradient(expl_target))
        expl_value_loss = -tf.reduce_mean(weights * like)
      expl_value_norm = self._expl_value_opt(
          expl_value_tape, expl_value_loss, [self.expl_value])
      metrics['expl_value_loss'] = expl_value_loss
      metrics['expl_value_grad_norm'] = expl_value_norm
      if self._config.precision == 16:
        metrics['expl_value_loss_scale'] = self._expl_value_opt.loss_scale
    return metrics

  def _compute_target(self, imag_feat, reward, slow):
    if self._config.reuse_task_behavior:
      expl_reward = tf.cast(reward['expl'], tf.float32)
      reward = tf.cast(reward['task'], tf.float32)
    else:
      reward = tf.cast(reward, tf.float32)
    if 'discount' in self._world_model.heads:
      discount = self._world_model.heads['discount'](
          imag_feat, tf.float32).mean()
    else:
      discount = self._config.discount * tf.ones_like(reward)
    if slow:
      value = self._slow_value(imag_feat, tf.float32).mode()
    else:
      value = self.value(imag_feat, tf.float32).mode()
    target = tools.lambda_return(
        reward[:-1], value[:-1], discount[:-1],
        bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0)
    if self._config.reuse_task_behavior:
      if slow:
        expl_value = self._slow_expl_value(imag_feat, tf.float32).mode()
      else:
        expl_value = self.expl_value(imag_feat, tf.float32).mode()
      expl_target = tools.lambda_return(
          expl_reward[:-1], expl_value[:-1], discount[:-1],
          bootstrap=expl_value[-1], lambda_=self._config.discount_lambda, axis=0)
      target = dict(task=target, expl=expl_target)
    weights = tf.stop_gradient(tf.math.cumprod(tf.concat(
        [tf.ones_like(discount[:1]), discount[:-2]], 0), 0))
    return target, weights

  def _imagine(self, start, repeats=None):
    if repeats:
      # Fold the repeats dimension into the time dimension instead of keeping
      # a separate repeats dimension after it.
      start = {k: tf.repeat(v, repeats, axis=1) for k, v in start.items()}
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    def step(prev, _):
      state, _, _ = prev
      feat = self._world_model.dynamics.get_feat(state)
      if self._config.dropout_imag:
        feat = tf.nn.dropout(feat, self._config.dropout_imag)
      action = self.actor(tf.stop_gradient(feat)).sample()
      state = self._world_model.dynamics.img_step(state, action)
      return state, feat, action
    feat = 0 * self._world_model.dynamics.get_feat(start)
    action = self.actor(feat).mode()
    states, feats, actions = tools.static_scan(
        step, tf.range(self._config.imag_horizon),
        (start, feat, action))
    if repeats:
      def unfold(tensor):
        s = tensor.shape
        return tf.reshape(tensor, [s[0], s[1] // repeats, repeats] + s[2:])
      states, feats, actions = tf.nest.map_structure(
          unfold, (states, feats, actions))
    return feats, states, actions


class RandomBehavior(tools.Module):

  def __init__(self, config):
    self._config = config
    self._float = prec.global_policy().compute_dtype

  def actor(self, feat):
    shape = feat.shape[:-1] + [self._config.num_actions]
    if self._config.actor_dist == 'onehot':
      return tools.OneHotDist(tf.zeros(shape))
    else:
      ones = tf.ones(shape, self._float)
      return tfd.Uniform(-ones, ones)

  def train(self, start, feat, embed):
    return {}
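

# Ensemble-disagreement exploration (in the spirit of Plan2Explore): an
# ensemble of dense heads predicts a chosen target (embedding, stoch, deter,
# or feat) one `disag_offset` step ahead, and the standard deviation across
# the ensemble means serves as the intrinsic reward for the wrapped behavior.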
class DisagBehavior(tools.Module):

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._behavior = behavior
    self.actor = self._behavior.actor
    size = self._config.disag_project or {
        'embed': config.cnn_embed_size or 32 * config.cnn_depth,
        'stoch': config.stoch_size,
        'deter': config.deter_size,
        'feat': config.stoch_size + config.deter_size,
    }[self._config.disag_target]
    kw = dict(
        shape=size, layers=config.disag_layers, units=config.disag_units,
        act=config.act)
    self._networks = [
        networks.DenseHead(**kw) for _ in range(config.disag_models)]
    if self._config.disag_project:
      self._project = tfkl.Dense(self._config.disag_project, trainable=False)
    self._opt = tools.Adam(
        'disag', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)

  def train(self, start, feat, embed):
    metrics = {}
    target = {
        'embed': embed,
        'stoch': start['stoch'],
        'deter': start['deter'],
        'feat': feat,
    }[self._config.disag_target]
    if self._config.disag_project:
      target = self._project(target)
    metrics.update(self._train_ensemble(feat, target))
    objective = lambda f, s, a: self._intrinsic_reward(f)
    metrics.update(self._behavior.train(start, objective))
    return metrics

  def _intrinsic_reward(self, inputs):
    preds = [head(inputs, tf.float32).mean() for head in self._networks]
    disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
    if self._config.disag_log:
      disag = tf.math.log(disag)
    return self._config.disag_scale * disag

  def _train_ensemble(self, inputs, targets):
    if self._config.disag_offset:
      targets = targets[:, self._config.disag_offset:]
      inputs = inputs[:, :-self._config.disag_offset]
    targets = tf.stop_gradient(targets)
    inputs = tf.stop_gradient(inputs)
    with tf.GradientTape() as tape:
      preds = [head(inputs) for head in self._networks]
      likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
      loss = -tf.reduce_sum(likes)
    norm = self._opt(tape, loss, self._networks)
    metrics = dict(disag_loss=loss, disag_grad_norm=norm)
    if self._config.precision == 16:
      metrics['disag_loss_scale'] = self._opt.loss_scale
    return metrics
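

# Information-gain variant of the disagreement bonus: with learned ensemble
# variances, the reward estimates a mutual information, i.e. the entropy of
# the ensemble mixture minus the average conditional entropy of its members
# (optionally computed with the low-memory sampling helpers below).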
class DisagInfoBehavior(tools.Module):

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._behavior = behavior
    self.actor = self._behavior.actor
    size = self._config.disag_project or {
        'embed': config.cnn_embed_size or 32 * config.cnn_depth,
        'stoch': config.stoch_size,
        'deter': config.deter_size,
        'feat': config.stoch_size + config.deter_size,
    }[self._config.disag_target]
    kw = dict(
        shape=size, layers=config.disag_layers,
        units=config.disag_units, act=config.act, std='learned')
    self._networks = [
        networks.DenseHead(**kw) for _ in range(config.disag_models)]
    if self._config.disag_project:
      self._project = tfkl.Dense(self._config.disag_project, trainable=False)
    self._opt = tools.Adam(
        'disag', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._float = prec.global_policy().compute_dtype

  def train(self, start, feat, embed):
    metrics = {}
    target = {
        'embed': embed,
        'stoch': start['stoch'],
        'deter': start['deter'],
        'feat': feat,
    }[self._config.disag_target]
    if self._config.disag_project:
      target = self._project(target)
    metrics.update(self._train_ensemble(feat, target))
    objective = lambda f, s, a: self._intrinsic_reward(f)
    metrics.update(self._behavior.train(start, objective))
    return metrics

  def _intrinsic_reward(self, inputs):
    S = self._config.disag_samples
    preds = [head(inputs, tf.float32) for head in self._networks]
    if self._config.disag_save_memory:
      cond = self._low_memory_cond_entropy(preds)
      marg = self._low_memory_marg_entropy(preds)
    else:
      cond = -tf.reduce_mean([
          pred.log_prob(pred.sample(S)) for pred in preds], [0, 1])
      H, BT, F = list(preds[0].mean().shape)
      K = len(self._networks)
      weights = tfd.Categorical(tf.zeros((H, BT, K), self._float))
      mixture = tfd.Mixture(weights, preds)
      marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
    return self._config.disag_scale * (marg - cond)

  def _train_ensemble(self, inputs, targets):
    if self._config.disag_offset:
      targets = targets[:, self._config.disag_offset:]
      inputs = inputs[:, :-self._config.disag_offset]
    targets = tf.stop_gradient(tf.cast(targets, tf.float32))
    inputs = tf.stop_gradient(inputs)
    with tf.GradientTape() as tape:
      preds = [head(inputs, tf.float32) for head in self._networks]
      likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
      loss = -tf.reduce_sum(likes)
    norm = self._opt(tape, loss, self._networks)
    return dict(disag_loss=loss, disag_grad_norm=norm)

  def _low_memory_cond_entropy(self, preds):
    entropies = []
    for pred in preds:
      logprobs = []
      for _ in range(self._config.disag_samples):
        logprobs.append(pred.log_prob(pred.sample()))
      entropies.append(-tf.reduce_mean(logprobs, 0))
    return tf.reduce_mean(entropies, 0)

  def _low_memory_marg_entropy(self, preds):
    H, BT, F = list(preds[0].mean().shape)
    K = len(self._networks)
    logprobs = []
    for _ in range(self._config.disag_samples):
      sample = tf.zeros((H, BT, F), self._float)
      index = tfd.Categorical(tf.zeros((H, BT, K), self._float)).sample()
      for i, p in enumerate(preds):
        sample += tf.cast(index == i, self._float)[:, :, None] * p.sample()
      prob = tf.zeros((H, BT), self._float)
      for p in preds:
        prob += p.prob(sample)
      prob /= K
      logprobs.append(tf.math.log(prob + 1e-8))
    return -tf.reduce_mean(logprobs, 0)


class DisagNoiseBehavior(tools.Module):

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._behavior = behavior
    self.actor = self._behavior.actor
    size = self._config.disag_project or {
        'embed': config.cnn_embed_size or 32 * config.cnn_depth,
        'stoch': config.stoch_size,
        'deter': config.deter_size,
        'feat': config.stoch_size + config.deter_size,
    }[self._config.disag_target]
    kw = dict(
        shape=size, layers=config.disag_layers,
        units=config.disag_units, act=config.act, std='learned')
    self._networks = [
        networks.DenseHead(**kw) for _ in range(config.disag_models)]
    if self._config.disag_project:
      self._project = tfkl.Dense(self._config.disag_project, trainable=False)
    self._opt = tools.Adam(
        'disag', config.model_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._float = prec.global_policy().compute_dtype

  def train(self, start, feat, embed):
    metrics = {}
    target = {
        'embed': embed,
        'stoch': start['stoch'],
        'deter': start['deter'],
        'feat': feat,
    }[self._config.disag_target]
    if self._config.disag_project:
      target = self._project(target)
    metrics.update(self._train_ensemble(feat, target))
    objective = lambda f, s, a: self._intrinsic_reward(f)
    metrics.update(self._behavior.train(start, objective))
    return metrics

  def _intrinsic_reward(self, inputs):
    preds = [head(inputs, tf.float32) for head in self._networks]
    means = [pred.mean() for pred in preds]
    marg = tf.math.log(tf.reduce_mean(tf.math.reduce_std(means, 0), -1))
    cond = tf.reduce_mean([pred.entropy() for pred in preds], 0)
    cond *= self._config.disag_noise_scale
    return self._config.disag_scale * (marg - cond)

  def _train_ensemble(self, inputs, targets):
    if self._config.disag_offset:
      targets = targets[:, self._config.disag_offset:]
      inputs = inputs[:, :-self._config.disag_offset]
    targets = tf.stop_gradient(tf.cast(targets, tf.float32))
    inputs = tf.stop_gradient(inputs)
    with tf.GradientTape() as tape:
      preds = [head(inputs, tf.float32) for head in self._networks]
      likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
      loss = -tf.reduce_sum(likes)
    norm = self._opt(tape, loss, self._networks)
    return dict(disag_loss=loss, disag_grad_norm=norm)
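

# Empowerment-style objectives: the Empow* behaviors reward mutual
# information between actions (or action sequences) and resulting latent
# states, estimated as the entropy of a uniform mixture over rollouts minus
# the corresponding conditional entropy.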
class EmpowOpenBehavior(tools.Module):  # I(S; A), open-loop, action space.

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor

  def train(self, start, feat, embed, task_reward=None):
    metrics = self._behavior.train(
        start, imagine=lambda s: self._imagine(s, task_reward))
    return metrics

  def _imagine(self, start, task_reward=None):
    parallel = 1 + self._config.empow_sequences
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    start = {k: tf.tile(v, [parallel, 1]) for k, v in start.items()}
    def step(prev, _):
      state, _, _ = prev
      feat = self._world_model.dynamics.get_feat(state)
      if self._config.dropout_imag:
        feat = tf.nn.dropout(feat, self._config.dropout_imag)
      # action = self.actor(tf.stop_gradient(
      #     tf.split(feat, parallel, 0)[0])).sample()
      # state = self._world_model.dynamics.img_step(
      #     state, tf.tile(action, [parallel, 1]))
      action = self.actor(tf.stop_gradient(feat)).sample()
      state = self._world_model.dynamics.img_step(state, action)
      return state, feat, action
    range_ = tf.range(self._config.imag_horizon)
    feat = 0 * self._world_model.dynamics.get_feat(start)
    action = 0 * self.actor(feat).mean()
    _, feat, action = tools.static_scan(step, range_, (start, feat, action))
    shape = (feat.shape[0], parallel, feat.shape[1] // parallel, feat.shape[2])
    feat = tf.transpose(tf.reshape(feat, shape), [0, 2, 1, 3])
    # feedback_feat = feat[:, :, :1]  # Time, batch, 1, features.
    # openloop_feat = feat[:, :, 1:]  # Time, batch, mixture, features.
    feedback_feat = []
    openloop_feat = []
    for i in range(parallel):
      feedback_feat.append(feat[:, :, i, :])
      openloop_feat.append(tf.concat([feat[:, :, :i, :], feat[:, :, i + 1:, :]], axis=2))
    feedback_feat = tf.stack(feedback_feat)  # Parallel, time, batch, features.
    openloop_feat = tf.stack(openloop_feat)  # Parallel, time, batch, mixture, features.
    reward = self._intrinsic_reward(feedback_feat, openloop_feat, action)
    feedback_feat = tf.transpose(feedback_feat, [1, 0, 2, 3])
    feedback_feat = tf.reshape(
        feedback_feat, [feedback_feat.shape[0], -1, feedback_feat.shape[3]])
    reward = tf.transpose(reward, [1, 0, 2])
    reward = tf.reshape(reward, [reward.shape[0], -1])
    if task_reward:
      t_reward = task_reward(feedback_feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return feedback_feat, reward

  def _intrinsic_reward(self, feedback_feat, openloop_feat, action):
    fb_dist = self.actor(feedback_feat, dtype=tf.float32)
    ol_dist = tools.uniform_mixture(
        self.actor(openloop_feat, dtype=tf.float32), dtype=tf.float32)
    if self._config.empow_entropy == 'sample':
      S = self._config.empow_samples
      cond = -tf.reduce_mean(fb_dist.log_prob(fb_dist.sample(S)), 0)
      marg = -tf.reduce_mean(ol_dist.log_prob(ol_dist.sample(S)), 0)
    elif self._config.empow_entropy == 'analytic':
      cond = fb_dist.entropy()
      if self._config.empow_sequences == 1:
        marg = ol_dist.entropy()
      elif ol_dist.dtype is tf.int32:
        marg = tools.cat_mixture_entropy(ol_dist)
    else:
      raise NotImplementedError(self._config.empow_entropy)
    assert marg.shape == cond.shape == feedback_feat.shape[0:3]
    return self._config.empow_scale * (marg - cond)


class EmpowActionBehavior(tools.Module):  # I(S; A), closed-loop, action space.

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor

  def train(self, start, feat, embed, task_reward=None):
    repeats = dict(
        sequences=self._config.empow_sequences, batch=None, time=None,
    )[self._config.empow_marginal]
    metrics = self._behavior.train(
        start,
        lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward=task_reward),
        repeats=repeats)
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    if self._config.empow_marginal == 'sequences':
      pass  # Trajectory has a trailing repeat dimension.
    elif self._config.empow_marginal == 'batch':
      pass  # Trajectory has a trailing batch dimension.
    elif self._config.empow_marginal == 'time':
      feat = tf.transpose(feat, [1, 0, 2])
      state = {k: tf.transpose(v, [1, 0, 2]) for k, v in state.items()}
    dist = self.actor(feat, dtype=tf.float32)
    mixture = tools.uniform_mixture(dist, dtype=tf.float32)
    if self._config.empow_entropy == 'sample':
      S = self._config.empow_samples
      cond = -tf.reduce_mean(dist.log_prob(dist.sample(S)), 0)
      marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
    elif self._config.empow_entropy == 'analytic':
      assert dist.dtype is tf.int32
      cond = dist.entropy()
      marg = tools.cat_mixture_entropy(dist)
    else:
      raise NotImplementedError(self._config.empow_entropy)
    marg = tf.repeat(marg[..., None], feat.shape[-2], -1)
    assert marg.shape == cond.shape == feat.shape[:-1]
    if self._config.empow_marginal == 'time':
      marg = tf.transpose(marg, [1, 0])
      cond = tf.transpose(cond, [1, 0])
    reward = self._config.empow_scale * (marg - cond)
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward


class EmpowStateBehavior(tools.Module):  # I(S; A), closed-loop, state space.

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor

  def train(self, start, feat, embed, task_reward=None):
    repeats = dict(
        sequences=self._config.empow_sequences, batch=None, time=None,
    )[self._config.empow_marginal]
    metrics = self._behavior.train(
        start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward),
        repeats=repeats)
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    if self._config.empow_marginal == 'sequences':
      pass  # Trajectory has a trailing repeat dimension.
    elif self._config.empow_marginal == 'batch':
      pass  # Trajectory has a trailing batch dimension.
    elif self._config.empow_marginal == 'time':
      feat = tf.transpose(feat, [1, 0, 2])
      state = {k: tf.transpose(v, [1, 0, 2]) for k, v in state.items()}
    dist = self._world_model.dynamics.get_dist(state, dtype=tf.float32)
    mixture = tools.uniform_mixture(dist, dtype=tf.float32)
    S = self._config.empow_samples
    cond = -tf.reduce_mean(dist.log_prob(dist.sample(S)), 0)
    marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
    marg = tf.repeat(marg[..., None], feat.shape[-2], -1)
    assert marg.shape == cond.shape == feat.shape[:-1]
    if self._config.empow_marginal == 'time':
      marg = tf.transpose(marg, [1, 0])
      cond = tf.transpose(cond, [1, 0])
    reward = self._config.empow_scale * (marg - cond)
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward


class EmpowStepBehavior(tools.Module):  # I(S'; A), state space.

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor

  def train(self, start, feat, embed, task_reward=None):
    metrics = self._behavior.train(
        start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward))
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    T, B, _ = feat.shape
    K = self._config.empow_sequences
    state = {k: tf.repeat(v, K, axis=1) for k, v in state.items()}
    action = self.actor(self._world_model.dynamics.get_feat(state)).sample()
    state = self._world_model.dynamics.img_step(state, action)
    state = {k: tf.reshape(v, (T, B, K, v.shape[-1])) for k, v in state.items()}
    dist = self._world_model.dynamics.get_dist(state, dtype=tf.float32)
    mixture = tools.uniform_mixture(dist, dtype=tf.float32)
    S = self._config.empow_samples
    cond = -tf.reduce_mean(dist.log_prob(dist.sample(S)), (0, -1))
    marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
    reward = self._config.empow_scale * (marg - cond)
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward
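

# Variational information maximization: planning distributions p(a_0 | s, s')
# and p(a_t | s, a_{t-1}, s') are fit to rollouts of a separate source policy,
# and a learned offset head regresses the scaled planning log-probability;
# tau * vim_horizon * offset is then used as the intrinsic reward.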
class EmpowVIMBehavior(tools.Module):  # I(S'; A), action space.

  def __init__(self, config, world_model, behavior):
    self._config = config
    self._world_model = world_model
    self._behavior = behavior
    self.actor = self._behavior.actor
    self._tau = config.temperature  # Temperature.
    self._plan_ss = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_disc)  # p(a_0 | s, s')
    self._plan_sas = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_disc)  # p(a_t | s, a_{t-1}, s')
    self._source = networks.ActionHead(
        config.num_actions, config.actor_layers, config.units, config.act,
        config.actor_dist, config.actor_init_std, config.actor_disc)
    self._offset = networks.DenseHead(
        [], config.actor_layers, config.units, config.act)
    self._plan_opt = tools.Adam(
        'plan', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)
    self._source_opt = tools.Adam(
        'source', config.actor_lr, config.opt_eps, config.grad_clip,
        config.weight_decay)  # For both source and offset.

  def train(self, start, feat, embed, task_reward=None):
    metrics = {}
    metrics.update(self._behavior.train(
        start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
    metrics.update(self._train_vim(start))
    return metrics

  def _intrinsic_reward(self, feat, state, action, task_reward=None):
    reward = self._tau * self._config.vim_horizon * self._offset(feat).mode()
    if task_reward:
      t_reward = task_reward(feat, 0, 0)
      reward = dict(expl=reward, task=t_reward)
    return reward

  def _imagine(self, start):  # Do not use this for self._behavior.train.
    flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    def step(prev, _):
      state, _, _ = prev
      feat = self._world_model.dynamics.get_feat(state)
      if self._config.dropout_imag:
        feat = tf.nn.dropout(feat, self._config.dropout_imag)
      action = self._source(tf.stop_gradient(feat)).sample()  # Source instead of actor.
      state = self._world_model.dynamics.img_step(state, action)
      return state, feat, action
    range_ = tf.range(self._config.imag_horizon)
    feat = 0 * self._world_model.dynamics.get_feat(start)
    action = 0 * self._source(feat).mean()
    _, feat, action = tools.static_scan(step, range_, (start, feat, action))
    return feat, action

  def _train_vim(self, start):
    feat, action = self._imagine(start)
    h = self._config.vim_horizon
    acts = []
    for i in range(len(action) - h):
      acts.append(action[i:i + h])
    acts = tf.stack(acts)  # Time-horizon, horizon, batch, features.
    ss_feat = tf.concat([feat[:-h], feat[h:]], axis=-1)  # Time-horizon, batch, features.
    if h > 1:
      sas_feat = tf.tile(tf.expand_dims(ss_feat, axis=1), [1, h - 1, 1, 1])
      sas_feat = tf.concat([sas_feat, acts[:, :-1]], axis=-1)  # Time-horizon, horizon-1, batch, features.
    action, acts = (
        tf.stop_gradient(tf.cast(action, tf.float32)),
        tf.stop_gradient(tf.cast(acts, tf.float32)))
    metrics = {}
    with tf.GradientTape() as tape:
      plan_ss_pred = self._plan_ss(ss_feat, dtype=tf.float32)
      log_plan_ss_prob = plan_ss_pred.log_prob(acts[:, 0])
      loss = -tf.reduce_mean(log_plan_ss_prob)
      if h > 1:
        plan_sas_pred = self._plan_sas(sas_feat, dtype=tf.float32)
        log_plan_sas_prob = tf.reduce_sum(plan_sas_pred.log_prob(acts[:, 1:]), axis=1)
        loss += -tf.reduce_mean(log_plan_sas_prob)
      loss *= 1. / h
    if h > 1:
      norm = self._plan_opt(tape, loss, [self._plan_ss, self._plan_sas])
    else:
      norm = self._plan_opt(tape, loss, [self._plan_ss])
    metrics['plan_grad_norm'] = norm
    metrics['plan_loss'] = loss
    if self._config.precision == 16:
      metrics['plan_dynamics_loss_scale'] = self._plan_opt.loss_scale
    with tf.GradientTape() as tape:
      offset = self._offset(feat[:-h], dtype=tf.float32).mode()
      action_pred = self._source(feat[:-h], dtype=tf.float32)
      r_sa = action_pred.log_prob(action[:-h]) + offset
      if h > 1:
        log_plan_prob = log_plan_ss_prob + log_plan_sas_prob
      else:
        log_plan_prob = log_plan_ss_prob
      loss = tf.reduce_mean(tf.square(
          1. / self._tau * 1. / h * tf.stop_gradient(log_plan_prob) - r_sa))
    norm = self._source_opt(tape, loss, [self._source, self._offset])
    metrics['source_grad_norm'] = norm
    metrics['source_loss'] = loss
    if self._config.precision == 16:
      metrics['source_loss_scale'] = self._source_opt.loss_scale
    return metrics