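# Dreamer-style latent world-model agent (gist by @alexlimh) with a family of
# intrinsic-motivation exploration behaviors (ensemble disagreement and
# empowerment variants) layered on top of the imagination-based task policy.
# Relies on companion modules `models`, `networks`, `tools`, and `wrappers`.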
import argparse
import collections
import functools
import gc
import os
import pathlib
import resource
import sys
import warnings
warnings.filterwarnings('ignore', '.*box bound precision lowered.*')
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'egl'
import numpy as np
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as prec
# tf.get_logger().setLevel('ERROR')
from tensorflow_probability import distributions as tfd
sys.path.append(str(pathlib.Path(__file__).parent))
import models
import tools
import wrappers
def define_config():
config = tools.AttrDict()
# General.
config.load_config = False
config.ckpt = True
config.logdir = pathlib.Path('.')
config.ckptdir = pathlib.Path('.')
config.traindir = None
config.evaldir = None
config.seed = 0
config.steps = 1e7
config.eval_every = 1e4
config.log_every = 1e4
config.backup_every = 1e3
config.gpu_growth = True
config.precision = 16
config.debug = False
config.expl_gifs = False
config.compress_dataset = False
config.gc = False
config.deleps = False
# Environment.
config.task = 'dmc_walker_walk'
config.size = (64, 64)
config.parallel = 'none'
config.envs = 1
config.action_repeat = 2
config.time_limit = 1000
config.prefill = 5000
config.eval_noise = 0.0
config.clip_rewards = 'identity'
config.atari_grayscale = False
config.atari_lifes = False
config.atari_fire_start = False
config.dmlab_path = None
# Model.
config.dynamics = 'rssm'
config.grad_heads = ('image', 'reward')
config.deter_size = 200
config.stoch_size = 50
config.stoch_depth = 10
config.dynamics_deep = False
config.units = 400
config.reward_layers = 2
config.discount_layers = 3
config.value_layers = 3
config.actor_layers = 4
config.act = 'elu'
config.cnn_depth = 32
config.encoder_kernels = (4, 4, 4, 4)
config.decoder_kernels = (5, 5, 6, 6)
config.decoder_thin = True
config.cnn_embed_size = 0
config.kl_scale = 1.0
config.free_nats = 3.0
config.kl_balance = 0.8
config.pred_discount = False
config.discount_rare_bonus = 0.0
config.discount_scale = 10.0
config.reward_scale = 2.0
config.weight_decay = 0.0
config.proprio = False
# Training.
config.batch_size = 50
config.batch_length = 50
config.train_every = 500
config.train_steps = 100
config.pretrain = 100
config.model_lr = 6e-4
config.value_lr = 8e-5
config.actor_lr = 8e-5
config.grad_clip = 100.0
config.opt_eps = 1e-5
config.value_grad_clip = 100.0
config.actor_grad_clip = 100.0
config.dataset_size = 0
config.dropout_embed = 0.0
config.dropout_feat = 0.0
config.dropout_imag = 0.0
config.oversample_ends = False
config.slow_value_target = False
config.slow_actor_target = False
config.slow_target_update = 0
config.slow_target_soft = False
config.slow_target_moving_weight = 0.1
# Behavior.
config.discount = 0.99
config.discount_lambda = 0.95
config.imag_horizon = 15
config.actor_dist = 'tanh_normal'
config.actor_disc = 5
config.actor_entropy = 0.0
config.actor_init_std = 5.0
config.expl = 'additive_gaussian'
config.expl_amount = 0.3
config.eval_state_mean = False
config.reuse_task_behavior = False
config.task_expl_reward_weights = 0.5
# Exploration.
config.expl_behavior = 'greedy'
config.expl_until = 0
config.disag_target = 'stoch' # embed
config.disag_scale = 5.0
config.disag_log = True
config.disag_samples = 10
config.disag_models = 10
config.disag_offset = 1
config.disag_layers = 4
config.vim_layers = 4
config.disag_units = 400
config.disag_project = 0 # 100
config.disag_noise_scale = 10.0
config.disag_save_memory = True
config.empow_scale = 1.0
config.empow_entropy = 'auto'
config.empow_marginal = 'sequences'
config.empow_sequences = 5 # 3
config.empow_samples = 5 # 3
config.temperature = 1.0
config.vim_horizon = 5
return config
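# Agent wrapper: owns the world model, the imagination-based task behavior, and
# an optional exploration behavior selected via config.expl_behavior. __call__
# acts in the environment and triggers model/behavior training every
# config.train_every agent steps.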
class Dreamer(tools.Module):
def __init__(self, config, logger, dataset):
self._config = config
self._logger = logger
self._float = prec.global_policy().compute_dtype
self._should_log = tools.Every(config.log_every)
self._should_train = tools.Every(config.train_every)
self._should_pretrain = tools.Once()
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
self._metrics = collections.defaultdict(tf.metrics.Mean)
with tf.device('cpu:0'):
self._step = tf.Variable(count_steps(config.traindir), dtype=tf.int64)
# self._strategy = tf.distribute.MirroredStrategy()
# with self._strategy.scope():
# self._dataset = self._strategy.experimental_distribute_dataset(dataset)
self._dataset = iter(dataset)
self._world_model = models.WorldModel(config)
self._behavior = models.ImagBehavior(config, self._world_model)
if not config.reuse_task_behavior:
behavior = models.ImagBehavior(config, self._world_model)
else:
behavior = self._behavior
self._expl_behavior = dict(
greedy=lambda: self._behavior,
random=lambda: models.RandomBehavior(config),
disag=lambda: models.DisagBehavior(config, self._world_model, behavior),
disag_info=lambda: models.DisagInfoBehavior(config, self._world_model, behavior),
disag_noise=lambda: models.DisagNoiseBehavior(config, self._world_model, behavior),
empow_open=lambda: models.EmpowOpenBehavior(config, self._world_model, behavior),
empow_action=lambda: models.EmpowActionBehavior(config, self._world_model, behavior),
empow_state=lambda: models.EmpowStateBehavior(config, self._world_model, behavior),
empow_step=lambda: models.EmpowStepBehavior(config, self._world_model, behavior),
empow_vim=lambda: models.EmpowVIMBehavior(config, self._world_model, behavior),
)[config.expl_behavior]()
# Train step to initialize variables including optimizer statistics.
self.train(next(self._dataset))
def __call__(self, obs, reset, state=None, training=True):
step = self._step.numpy().item()
if state is not None and reset.any():
mask = tf.cast(1 - reset, self._float)[:, None]
state = tf.nest.map_structure(lambda x: x * mask, state)
if training and self._should_train(step):
# with self._strategy.scope():
steps = (
self._config.pretrain if self._should_pretrain()
else self._config.train_steps)
for _ in range(steps):
self.train(next(self._dataset))
if self._should_log(step):
for name, mean in self._metrics.items():
self._logger.scalar(name, float(mean.result()))
mean.reset_states()
openl = self._world_model.video_pred(next(self._dataset))
self._logger.video('train_openl', openl)
self._logger.write(fps=True)
action, state = self._policy(obs, state, training)
if training:
self._step.assign_add(len(reset))
self._logger.step = self._config.action_repeat * self._step.numpy().item()
return action, state
@tf.function
def _policy(self, obs, state, training):
if state is None:
batch_size = len(obs['image'])
latent = self._world_model.dynamics.initial(len(obs['image']))
action = tf.zeros((batch_size, self._config.num_actions), self._float)
else:
latent, action = state
embed = self._world_model.encoder(self._world_model.preprocess(obs))
latent, _ = self._world_model.dynamics.obs_step(latent, action, embed)
if self._config.eval_state_mean:
latent['stoch'] = latent['mean']
feat = self._world_model.dynamics.get_feat(latent)
if not training:
action = self._behavior.actor(feat).mode()
elif self._should_expl(self._step):
action = self._expl_behavior.actor(feat).sample()
else:
action = self._behavior.actor(feat).sample()
action = self._exploration(action, training)
state = (latent, action)
return action, state
def _exploration(self, action, training):
amount = self._config.expl_amount if training else self._config.eval_noise
if amount == 0:
return action
amount = tf.cast(amount, self._float)
if self._config.expl == 'additive_gaussian':
return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
if self._config.expl == 'epsilon_greedy':
probs = amount / self._config.num_actions + (1 - amount) * action
return tools.OneHotDist(probs=probs).sample()
raise NotImplementedError(self._config.expl)
# @tf.function
# def train(self, data):
# self._strategy.experimental_run_v2(self._train, (data,))
@tf.function
def train(self, data):
metrics = {}
embed, post, feat, mets = self._world_model.train(data)
metrics.update(mets)
reward = lambda f, s, a: self._world_model.heads['reward'](f).mode()
if not self._config.reuse_task_behavior:
metrics.update(self._behavior.train(post, reward))
if self._config.expl_behavior != 'greedy':
if self._config.reuse_task_behavior:
mets = self._expl_behavior.train(post, feat, embed, reward)
else:
mets = self._expl_behavior.train(post, feat, embed)
metrics.update({'expl_' + key: value for key, value in mets.items()})
# if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
# if self._config.pred_discount:
# metrics.update(self._episode_end_precision_recall(data, feat))
for name, value in metrics.items():
self._metrics[name].update_state(value)
# def _episode_end_precision_recall(self, data, feat):
# pred = self._world_model.heads['discount'](feat).mean()
# ppos, pneg = pred <= 0.5, pred > 0.5
# dpos, dneg = data['discount'] <= 0.5, data['discount'] > 0.5
# tp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dpos), self._float))
# tn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dneg), self._float))
# fp = tf.reduce_sum(tf.cast(tf.logical_and(ppos, dneg), self._float))
# fn = tf.reduce_sum(tf.cast(tf.logical_and(pneg, dpos), self._float))
# tpr, tnr, ppv = tp / (tp + fn), tn / (tn + fp), tp / (tp + fp)
# ba, f1 = (tpr + tnr) / 2, 2 * ppv * tpr / (ppv + tpr)
# any_ = tf.cast(tf.reduce_any(dpos), self._float)
# metrics = dict(
# ee_tp=tp, ee_tn=tn, ee_fp=fp, ee_fn=fn, ee_tpr=tpr, ee_tnr=tnr,
# ee_ppv=ppv, ee_ba=ba, ee_f1=f1, ee_any=any_)
# return {
# k: tf.where(tf.math.is_nan(v), tf.zeros_like(v), v)
# for k, v in metrics.items()}
def count_steps(directory):
return sum(int(str(n).split('-')[-1][:-4]) - 1 for n in directory.glob('*.npz'))
def make_dataset(episodes, config):
example = episodes[next(iter(episodes.keys()))]
example = tools.decompress_episode(example)
types = {k: v.dtype for k, v in example.items()}
shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
generator = lambda: tools.sample_episodes(
episodes, config.batch_length, config.oversample_ends)
dataset = tf.data.Dataset.from_generator(generator, types, shapes)
dataset = dataset.batch(config.batch_size, drop_remainder=True)
dataset = dataset.prefetch(10)
return dataset
def make_env(config, logger, mode, train_eps, eval_eps):
suite, task = config.task.split('_', 1)
if suite == 'dmc' or suite == 'dmctoy':
env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
env = wrappers.NormalizeActions(env)
elif suite == 'atari':
env = wrappers.Atari(
task, config.action_repeat, config.size,
grayscale=config.atari_grayscale, sticky_actions=True,
life_done=config.atari_lifes and (mode == 'train'))
env = wrappers.OneHotAction(env)
elif suite == 'dmlab':
args = task, config.action_repeat, config.size
env = wrappers.DeepMindLab(*args, path=config.dmlab_path)
env = wrappers.OneHotAction(env)
elif suite == 'mc':
args = task, mode, config.size, config.action_repeat
env = wrappers.Minecraft(*args)
env = wrappers.OneHotAction(env)
else:
raise NotImplementedError(suite)
env = wrappers.TimeLimit(env, config.time_limit)
callbacks = [functools.partial(
process_episode, config, logger, mode, train_eps, eval_eps)]
env = wrappers.Collect(env, callbacks, config.precision)
env = wrappers.RewardObs(env)
return env
def process_episode(config, logger, mode, train_eps, eval_eps, episode):
directory = dict(train=config.traindir, eval=config.evaldir)[mode]
cache = dict(train=train_eps, eval=eval_eps)[mode]
filename = tools.save_episodes(directory, [episode])[0]
length = len(episode['reward']) - 1
score = float(episode['reward'].astype(np.float32).sum())
video = episode['image']
if mode == 'eval':
if config.deleps:
for key, value in list(cache.items()):
del cache[key]
del value
cache.clear()
if mode == 'train' and config.dataset_size:
total = 0
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
if total <= config.dataset_size - length:
total += len(ep['reward']) - 1
else:
del cache[key]
if config.deleps:
del ep
logger.scalar('dataset_size', total + length)
if config.compress_dataset:
episode = tools.compress_episode(episode, (np.uint8,))
cache[str(filename)] = episode
print(f'{mode.title()} episode has {length} steps and return {score:.1f}.')
logger.scalar(f'{mode}_return', score)
logger.scalar(f'{mode}_length', length)
logger.scalar(f'{mode}_episodes', len(cache))
if mode == 'eval' or config.expl_gifs:
logger.video(f'{mode}_policy', video[None])
logger.write()
def main(config):
if config.load_config:
from configs import default_configs
suite, _ = config.task.split('_', 1)
default = default_configs[suite]
config = vars(config)
for key, value in default.items():
config[key] = value
config = tools.AttrDict(config)
print(config)
config.traindir = config.traindir or config.logdir / 'train_eps'
config.evaldir = config.evaldir or config.logdir / 'eval_eps'
config.steps //= config.action_repeat
config.eval_every //= config.action_repeat
config.log_every //= config.action_repeat
config.backup_every //= config.action_repeat
config.time_limit //= config.action_repeat
config.act = getattr(tf.nn, config.act)
if config.empow_entropy == 'auto':
options = dict(
tanh_normal='sample', onehot='sample', discretized='analytic')
config.empow_entropy = options[config.actor_dist]
if config.debug:
tf.config.experimental_run_functions_eagerly(True)
if config.gpu_growth:
for gpu in tf.config.experimental.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(gpu, True)
assert config.precision in (16, 32), config.precision
if config.precision == 16:
prec.set_policy(prec.Policy('mixed_float16'))
print('Logdir', config.logdir)
config.logdir.mkdir(parents=True, exist_ok=True)
step = count_steps(config.traindir)
logger = tools.Logger(config.logdir, config.action_repeat * step)
print('Create envs.')
train_eps = tools.load_episodes(config.traindir, limit=config.dataset_size)
eval_eps = tools.load_episodes(config.evaldir, limit=1)
make1 = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
make2 = lambda mode: wrappers.Async(lambda: make1(mode), config.parallel)
train_envs = [make2('train') for _ in range(config.envs)]
eval_envs = [make2('eval') for _ in range(config.envs)]
acts = train_envs[0].action_space
config.num_actions = acts.n if hasattr(acts, 'n') else acts.shape[0]
prefill = max(0, config.prefill - count_steps(config.traindir))
print(f'Prefill dataset ({prefill} steps).')
random_agent = lambda o, d, s: ([acts.sample() for _ in d], s)
tools.simulate(random_agent, train_envs, prefill)
tools.simulate(random_agent, eval_envs, episodes=1)
logger.step = config.action_repeat * count_steps(config.traindir)
print('Simulate agent.')
train_dataset = make_dataset(train_eps, config)
eval_dataset = iter(make_dataset(eval_eps, config))
agent = Dreamer(config, logger, train_dataset)
if (config.logdir / 'variables.pkl').exists():
try:
agent.load(config.logdir / 'variables.pkl')
except Exception:
agent.load(config.logdir / 'variables-backup.pkl')
agent._should_pretrain._once = False
state = None
should_backup = tools.Every(config.backup_every)
while agent._step.numpy().item() < config.steps:
print('Process statistics.')
if config.gc:
gc.collect()
logger.scalar('python_objects', len(gc.get_objects()))
with open('/proc/self/statm') as f:
ram = int(f.read().split()[1]) * resource.getpagesize() / 1024 / 1024 / 1024
logger.scalar('ram', ram)
logger.write()
print('Start evaluation.')
video_pred = agent._world_model.video_pred(next(eval_dataset))
logger.video('eval_openl', video_pred)
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, episodes=1)
print('Start training.')
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
if config.ckpt:
agent.save(config.ckptdir / 'variables.pkl')
agent.save(config.logdir / 'variables.pkl')
if should_backup(agent._step.numpy().item()):
if config.ckpt:
agent.save(config.ckptdir / 'variables-backup.pkl')
agent.save(config.logdir / 'variables-backup.pkl')
for env in train_envs + eval_envs:
env.close()
if __name__ == '__main__':
print(os.getpid())
try:
import colored_traceback
colored_traceback.add_hook()
except ImportError:
pass
parser = argparse.ArgumentParser()
for key, value in define_config().items():
parser.add_argument(f'--{key}', type=tools.args_type(value), default=value)
main(parser.parse_args())
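# ---------------------------------------------------------------------------
# Example invocation (a sketch; the paths are placeholders and the filename
# dreamer.py is an assumption, but --task and --logdir are real flags from
# define_config above):
#   python dreamer.py --task dmc_walker_walk --logdir ./logdir/walker_walk
# ---------------------------------------------------------------------------
# The code below is referenced above through `import models` (e.g.
# models.WorldModel, models.ImagBehavior), so it presumably lives in a
# separate models.py file within this gist.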
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd
import networks
import tools
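# Latent dynamics model: convolutional encoder, (discrete) RSSM dynamics, and
# decoder heads for image, reward, and optionally discount, trained jointly
# with a KL-regularized reconstruction objective.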
class WorldModel(tools.Module):
def __init__(self, config):
self._config = config
self.encoder = networks.ConvEncoder(
config.cnn_depth, config.act, config.encoder_kernels,
config.cnn_embed_size or None, config.proprio)
if config.dynamics == 'rssm':
self.dynamics = networks.RSSM(
config.stoch_size, config.deter_size, config.deter_size,
deep=config.dynamics_deep)
elif config.dynamics == 'discrete_rssm':
self.dynamics = networks.DiscreteRSSM(
config.stoch_size, config.stoch_depth,
config.deter_size, config.deter_size)
else:
raise NotImplementedError(config.dynamics)
self.heads = {}
shape = config.size + (1 if config.atari_grayscale else 3,)
self.heads['image'] = networks.ConvDecoder(
config.cnn_depth, config.act, shape, config.decoder_kernels,
config.decoder_thin)
self.heads['reward'] = networks.DenseHead(
[], config.reward_layers, config.units, config.act)
if config.pred_discount:
self.heads['discount'] = networks.DenseHead(
[], config.discount_layers, config.units, config.act, dist='binary')
for name in config.grad_heads:
assert name in self.heads, name
self._model_opt = tools.Adam(
'model', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay)
self._scales = dict(
reward=config.reward_scale, discount=config.discount_scale)
self._float = prec.global_policy().compute_dtype
def train(self, data):
data = self.preprocess(data)
with tf.GradientTape() as model_tape:
embed = self.encoder(data)
if self._config.dropout_embed:
embed = tf.nn.dropout(embed, self._config.dropout_embed)
post, prior = self.dynamics.observe(embed, data['action'])
kl_loss, kl_value = self.dynamics.kl_loss(
post, prior, self._config.kl_balance, self._config.free_nats,
self._config.kl_scale, tf.float32)
feat = self.dynamics.get_feat(post)
if self._config.dropout_feat:
feat = tf.nn.dropout(feat, self._config.dropout_feat)
likes = {}
for name, head in self.heads.items():
grad_head = (name in self._config.grad_heads)
inp = feat if grad_head else tf.stop_gradient(feat)
pred = head(inp) # head(inp, tf.float32)
like = pred.log_prob(data[name])
if name == 'discount' and self._config.discount_rare_bonus:
rare = tf.cast(data['discount'] < 0.5, self._float)
like *= 1 + self._config.discount_rare_bonus * rare
likes[name] = tf.reduce_mean(like) * self._scales.get(name, 1.0)
model_loss = kl_loss - sum(likes.values())
model_parts = [self.encoder, self.dynamics] + list(self.heads.values())
model_norm = self._model_opt(model_tape, model_loss, model_parts)
metrics = dict(
**{f'{name}_loss': -like for name, like in likes.items()},
kl=kl_value, model_loss=model_loss, model_grad_norm=model_norm,
prior_ent=self.dynamics.get_dist(prior).entropy(),
post_ent=self.dynamics.get_dist(post).entropy(),
)
if self._config.precision == 16:
metrics['model_loss_scale'] = self._model_opt.loss_scale
return embed, post, feat, metrics
@tf.function
def preprocess(self, obs):
dtype = prec.global_policy().compute_dtype
obs = obs.copy()
obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5
obs['reward'] = getattr(tf, self._config.clip_rewards)(obs['reward'])
if 'discount' in obs:
obs['discount'] *= self._config.discount
return obs
@tf.function
def video_pred(self, data):
data = self.preprocess(data)
truth = data['image'][:6] + 0.5
embed = self.encoder(data)
states, _ = self.dynamics.observe(embed[:6, :5], data['action'][:6, :5])
recon = self.heads['image'](self.dynamics.get_feat(states)).mode()[:6]
init = {k: v[:, -1] for k, v in states.items()}
prior = self.dynamics.imagine(data['action'][:6, 5:], init)
openl = self.heads['image'](self.dynamics.get_feat(prior)).mode()
model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
error = (model - truth + 1) / 2
return tf.concat([truth, model, error], 2)
# @tf.function
# def video_pred_strip(self, data):
# data = self.preprocess(data)
# truth = data['image'][:6] + 0.5
# embed = self.encoder(data)
# states, _ = self.dynamics.observe(embed[:6, :5], data['action'][:6, :5])
# recon = self.heads['image'](self.dynamics.get_feat(states)).mode()[:6]
# init = {k: v[:, -1] for k, v in states.items()}
# prior = self.dynamics.imagine(data['action'][:6, 5:], init)
# openl = self.heads['image'](self.dynamics.get_feat(prior)).mode()
# model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
# B, T, H, W, C = list(truth.shape)
# truth = tf.reshape(tf.transpose(truth, (0,2,1,3,4)), (B*H, T*W, C))
# model = tf.reshape(tf.transpose(model, (0,2,1,3,4)), (B*H, T*W, C))
# return truth, model
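# Actor-critic trained purely on imagined rollouts from the world model, using
# lambda-returns over an imag_horizon-step horizon, with optional slow-target
# value networks and an optional separate exploration value head.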
class ImagBehavior(tools.Module):
def __init__(self, config, world_model):
self._config = config
self._world_model = world_model
self.actor = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc)
self.value = networks.DenseHead(
[], config.value_layers, config.units, config.act)
if config.slow_value_target or config.slow_actor_target:
self._slow_value = networks.DenseHead(
[], config.value_layers, config.units, config.act)
self._updates = tf.Variable(0, tf.int64)
kw = dict(wd=config.weight_decay)
self._actor_opt = tools.Adam(
'actor', config.actor_lr, config.opt_eps, config.actor_grad_clip, **kw)
self._value_opt = tools.Adam(
'value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw)
if config.reuse_task_behavior:
self.expl_value = networks.DenseHead(
[], config.value_layers, config.units, config.act)
if config.slow_value_target or config.slow_actor_target:
self._slow_expl_value = networks.DenseHead(
[], config.value_layers, config.units, config.act)
self._expl_value_opt = tools.Adam(
'expl_value', config.value_lr, config.opt_eps, config.value_grad_clip, **kw)
def train(self, start, objective=None, imagine=None, repeats=None):
assert bool(objective) != bool(imagine)
if self._config.slow_value_target or self._config.slow_actor_target:
if self._updates % self._config.slow_target_update == 0:
if not self._config.slow_target_soft:
for s, d in zip(self.value.variables, self._slow_value.variables):
d.assign(s)
if self._config.reuse_task_behavior:
for s, d in zip(self.expl_value.variables, self._slow_expl_value.variables):
d.assign(s)
else:
alpha = self._config.slow_target_moving_weight
for s, d in zip(self.value.variables, self._slow_value.variables):
d.assign(alpha*s+(1-alpha)*d)
if self._config.reuse_task_behavior:
for s, d in zip(self.expl_value.variables, self._slow_expl_value.variables):
d.assign(alpha*s+(1-alpha)*d)
self._updates.assign_add(1)
if self._config.pred_discount: # Last step could be terminal.
start = {k: v[:, :-1] for k, v in start.items()}
with tf.GradientTape() as actor_tape:
if objective:
imag_feat, imag_state, imag_action = self._imagine(start, repeats)
reward = objective(imag_feat, imag_state, imag_action)
else:
imag_feat, reward = imagine(start)
target, weights = self._compute_target(
imag_feat, reward, self._config.slow_actor_target)
actor_ent = self.actor(
tf.stop_gradient(imag_feat), tf.float32).entropy()
if self._config.actor_entropy:
bonus = self._config.actor_entropy * actor_ent[:-1]
else:
bonus = 0.0
if self._config.reuse_task_behavior:
t = (1-self._config.task_expl_reward_weights) * target['task'] + self._config.task_expl_reward_weights * target['expl']
actor_loss = -tf.reduce_mean(weights * (t + bonus))
else:
actor_loss = -tf.reduce_mean(weights * (target + bonus))
# Value loss
if self._config.slow_value_target != self._config.slow_actor_target:
target, weights = self._compute_target(
imag_feat, reward, self._config.slow_value_target)
if self._config.reuse_task_behavior:
target, expl_target = target['task'], target['expl']
with tf.GradientTape() as value_tape:
like = self.value(imag_feat, tf.float32)[:-1].log_prob(
tf.stop_gradient(target))
value_loss = -tf.reduce_mean(weights * like)
actor_norm = self._actor_opt(actor_tape, actor_loss, [self.actor])
value_norm = self._value_opt(value_tape, value_loss, [self.value])
metrics = dict(
actor_loss=actor_loss, actor_grad_norm=actor_norm,
value_loss=value_loss, value_grad_norm=value_norm,
actor_ent=actor_ent)
if self._config.precision == 16:
metrics['actor_loss_scale'] = self._actor_opt.loss_scale
metrics['value_loss_scale'] = self._value_opt.loss_scale
if self._config.reuse_task_behavior:
with tf.GradientTape() as expl_value_tape:
like = self.expl_value(imag_feat, tf.float32)[:-1].log_prob(
tf.stop_gradient(expl_target))
expl_value_loss = -tf.reduce_mean(weights * like)
expl_value_norm = self._expl_value_opt(expl_value_tape, expl_value_loss, [self.expl_value])
metrics['expl_value_loss'] = expl_value_loss
metrics['expl_value_grad_norm'] = expl_value_norm
if self._config.precision == 16:
metrics['expl_value_loss_scale'] = self._expl_value_opt.loss_scale
return metrics
def _compute_target(self, imag_feat, reward, slow):
if self._config.reuse_task_behavior:
expl_reward = tf.cast(reward['expl'], tf.float32)
reward = tf.cast(reward['task'], tf.float32)
else:
reward = tf.cast(reward, tf.float32)
if 'discount' in self._world_model.heads:
discount = self._world_model.heads['discount'](
imag_feat, tf.float32).mean()
else:
discount = self._config.discount * tf.ones_like(reward)
if slow:
value = self._slow_value(imag_feat, tf.float32).mode()
else:
value = self.value(imag_feat, tf.float32).mode()
target = tools.lambda_return(
reward[:-1], value[:-1], discount[:-1],
bootstrap=value[-1], lambda_=self._config.discount_lambda, axis=0)
if self._config.reuse_task_behavior:
if slow:
expl_value = self._slow_expl_value(imag_feat, tf.float32).mode()
else:
expl_value = self.expl_value(imag_feat, tf.float32).mode()
expl_target = tools.lambda_return(
expl_reward[:-1], expl_value[:-1], discount[:-1],
bootstrap=expl_value[-1], lambda_=self._config.discount_lambda, axis=0)
target = dict(task=target, expl=expl_target)
weights = tf.stop_gradient(tf.math.cumprod(tf.concat(
[tf.ones_like(discount[:1]), discount[:-2]], 0), 0))
return target, weights
def _imagine(self, start, repeats=None):
if repeats:
# Rather than adding a separate repeats dim after the time dim, fold it into the time dim.
start = {k: tf.repeat(v, repeats, axis=1) for k, v in start.items()}
flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = self._world_model.dynamics.get_feat(state)
if self._config.dropout_imag:
feat = tf.nn.dropout(feat, self._config.dropout_imag)
action = self.actor(tf.stop_gradient(feat)).sample()
state = self._world_model.dynamics.img_step(state, action)
return state, feat, action
feat = 0 * self._world_model.dynamics.get_feat(start)
action = self.actor(feat).mode()
states, feats, actions = tools.static_scan(
step, tf.range(self._config.imag_horizon),
(start, feat, action))
if repeats:
def unfold(tensor):
s = tensor.shape
return tf.reshape(tensor, [s[0], s[1] // repeats, repeats] + s[2:])
states, feats, actions = tf.nest.map_structure(
unfold, (states, feats, actions))
return feats, states, actions
class RandomBehavior(tools.Module):
def __init__(self, config):
self._config = config
self._float = prec.global_policy().compute_dtype
def actor(self, feat):
shape = feat.shape[:-1] + [self._config.num_actions]
if self._config.actor_dist == 'onehot':
return tools.OneHotDist(tf.zeros(shape))
else:
ones = tf.ones(shape, self._float)
return tfd.Uniform(-ones, ones)
def train(self, start, feat, embed):
return {}
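# Ensemble-disagreement exploration: one-step prediction heads are trained on a
# chosen latent target (embed/stoch/deter/feat), and the (optionally log-)
# standard deviation across their means serves as the intrinsic reward.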
class DisagBehavior(tools.Module):
def __init__(self, config, world_model, behavior):
self._config = config
self._behavior = behavior
self.actor = self._behavior.actor
size = self._config.disag_project or {
'embed': config.cnn_embed_size or 32 * config.cnn_depth,
'stoch': config.stoch_size,
'deter': config.deter_size,
'feat': config.stoch_size + config.deter_size,
}[self._config.disag_target]
kw = dict(
shape=size, layers=config.disag_layers, units=config.disag_units,
act=config.act)
self._networks = [
networks.DenseHead(**kw) for _ in range(config.disag_models)]
if self._config.disag_project:
self._project = tfkl.Dense(self._config.disag_project, trainable=False)
self._opt = tools.Adam(
'disag', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay)
def train(self, start, feat, embed):
metrics = {}
target = {
'embed': embed,
'stoch': start['stoch'],
'deter': start['deter'],
'feat': feat,
}[self._config.disag_target]
if self._config.disag_project:
target = self._project(target)
metrics.update(self._train_ensemble(feat, target))
objective = lambda f, s, a: self._intrinsic_reward(f)
metrics.update(self._behavior.train(start, objective))
return metrics
def _intrinsic_reward(self, inputs):
preds = [head(inputs, tf.float32).mean() for head in self._networks]
disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
if self._config.disag_log:
disag = tf.math.log(disag)
return self._config.disag_scale * disag
def _train_ensemble(self, inputs, targets):
if self._config.disag_offset:
targets = targets[:, self._config.disag_offset:]
inputs = inputs[:, :-self._config.disag_offset]
targets = tf.stop_gradient(targets)
inputs = tf.stop_gradient(inputs)
with tf.GradientTape() as tape:
preds = [head(inputs) for head in self._networks]
likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
loss = -tf.reduce_sum(likes)
norm = self._opt(tape, loss, self._networks)
metrics = dict(disag_loss=loss, disag_grad_norm=norm)
if self._config.precision == 16:
metrics['disag_loss_scale'] = self._opt.loss_scale
return metrics
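# Disagreement variant with learned-variance heads: the intrinsic reward is the
# estimated entropy gap between the ensemble mixture and its members (a
# mutual-information-style bonus), with optional low-memory sample estimators.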
class DisagInfoBehavior(tools.Module):
def __init__(self, config, world_model, behavior):
self._config = config
self._behavior = behavior
self.actor = self._behavior.actor
size = self._config.disag_project or {
'embed': config.cnn_embed_size or 32 * config.cnn_depth,
'stoch': config.stoch_size,
'deter': config.deter_size,
'feat': config.stoch_size + config.deter_size,
}[self._config.disag_target]
kw = dict(
shape=size, layers=config.disag_layers,
units=config.disag_units, act=config.act, std='learned')
self._networks = [
networks.DenseHead(**kw) for _ in range(config.disag_models)]
if self._config.disag_project:
self._project = tfkl.Dense(self._config.disag_project, trainable=False)
self._opt = tools.Adam(
'disag', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay)
self._float = prec.global_policy().compute_dtype
def train(self, start, feat, embed):
metrics = {}
target = {
'embed': embed,
'stoch': start['stoch'],
'deter': start['deter'],
'feat': feat,
}[self._config.disag_target]
if self._config.disag_project:
target = self._project(target)
metrics.update(self._train_ensemble(feat, target))
objective = lambda f, s, a: self._intrinsic_reward(f)
metrics.update(self._behavior.train(start, objective))
return metrics
def _intrinsic_reward(self, inputs):
S = self._config.disag_samples
preds = [head(inputs, tf.float32) for head in self._networks]
if self._config.disag_save_memory:
cond = self._low_memory_cond_entropy(preds)
marg = self._low_memory_marg_entropy(preds)
else:
cond = -tf.reduce_mean([
pred.log_prob(pred.sample(S)) for pred in preds], [0, 1])
H, BT, F = list(preds[0].mean().shape)
K = len(self._networks)
weights = tfd.Categorical(tf.zeros((H, BT, K), self._float))
mixture = tfd.Mixture(weights, preds)
marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
return self._config.disag_scale * (marg - cond)
def _train_ensemble(self, inputs, targets):
if self._config.disag_offset:
targets = targets[:, self._config.disag_offset:]
inputs = inputs[:, :-self._config.disag_offset]
targets = tf.stop_gradient(tf.cast(targets, tf.float32))
inputs = tf.stop_gradient(inputs)
with tf.GradientTape() as tape:
preds = [head(inputs, tf.float32) for head in self._networks]
likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
loss = -tf.reduce_sum(likes)
norm = self._opt(tape, loss, self._networks)
return dict(disag_loss=loss, disag_grad_norm=norm)
def _low_memory_cond_entropy(self, preds):
entropies = []
for pred in preds:
logprobs = []
for _ in range(self._config.disag_samples):
logprobs.append(pred.log_prob(pred.sample()))
entropies.append(-tf.reduce_mean(logprobs, 0))
return tf.reduce_mean(entropies, 0)
def _low_memory_marg_entropy(self, preds):
H, BT, F = list(preds[0].mean().shape)
K = len(self._networks)
logprobs = []
for _ in range(self._config.disag_samples):
sample = tf.zeros((H, BT, F), self._float)
index = tfd.Categorical(tf.zeros((H, BT, K), self._float)).sample()
for i, p in enumerate(preds):
sample += tf.cast(index == i, self._float)[:, :, None] * p.sample()
prob = tf.zeros((H, BT), self._float)
for p in preds:
prob += p.prob(sample)
prob /= K
logprobs.append(tf.math.log(prob + 1e-8))
return -tf.reduce_mean(logprobs, 0)
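# Disagreement variant that contrasts the spread of the ensemble means
# (marginal term) with the members' own predictive entropy (conditional term,
# scaled by disag_noise_scale).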
class DisagNoiseBehavior(tools.Module):
def __init__(self, config, world_model, behavior):
self._config = config
self._behavior = behavior
self.actor = self._behavior.actor
size = self._config.disag_project or {
'embed': config.cnn_embed_size or 32 * config.cnn_depth,
'stoch': config.stoch_size,
'deter': config.deter_size,
'feat': config.stoch_size + config.deter_size,
}[self._config.disag_target]
kw = dict(
shape=size, layers=config.disag_layers,
units=config.disag_units, act=config.act, std='learned')
self._networks = [
networks.DenseHead(**kw) for _ in range(config.disag_models)]
if self._config.disag_project:
self._project = tfkl.Dense(self._config.disag_project, trainable=False)
self._opt = tools.Adam(
'disag', config.model_lr, config.opt_eps, config.grad_clip,
config.weight_decay)
self._float = prec.global_policy().compute_dtype
def train(self, start, feat, embed):
metrics = {}
target = {
'embed': embed,
'stoch': start['stoch'],
'deter': start['deter'],
'feat': feat,
}[self._config.disag_target]
if self._config.disag_project:
target = self._project(target)
metrics.update(self._train_ensemble(feat, target))
objective = lambda f, s, a: self._intrinsic_reward(f)
metrics.update(self._behavior.train(start, objective))
return metrics
def _intrinsic_reward(self, inputs):
preds = [head(inputs, tf.float32) for head in self._networks]
means = [pred.mean() for pred in preds]
marg = tf.math.log(tf.reduce_mean(tf.math.reduce_std(means, 0), -1))
cond = tf.reduce_mean([pred.entropy() for pred in preds], 0)
cond *= self._config.disag_noise_scale
return self._config.disag_scale * (marg - cond)
def _train_ensemble(self, inputs, targets):
if self._config.disag_offset:
targets = targets[:, self._config.disag_offset:]
inputs = inputs[:, :-self._config.disag_offset]
targets = tf.stop_gradient(tf.cast(targets, tf.float32))
inputs = tf.stop_gradient(inputs)
with tf.GradientTape() as tape:
preds = [head(inputs, tf.float32) for head in self._networks]
likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
loss = -tf.reduce_sum(likes)
norm = self._opt(tape, loss, self._networks)
return dict(disag_loss=loss, disag_grad_norm=norm)
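# Open-loop empowerment: rolls out several action sequences in parallel from
# the same starting states and rewards the gap between the marginal (mixture
# over sequences) and conditional action entropies, i.e. a sample-based
# estimate of I(S; A) in action space.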
class EmpowOpenBehavior(tools.Module): # I(S; A), open-loop, action space
def __init__(self, config, world_model, behavior):
self._config = config
self._world_model = world_model
self._behavior = behavior
self.actor = self._behavior.actor
def train(self, start, feat, embed, task_reward=None):
metrics = self._behavior.train(start, imagine=lambda s: self._imagine(s, task_reward))
return metrics
def _imagine(self, start, task_reward=None):
parallel = 1 + self._config.empow_sequences
flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
start = {k: tf.tile(v, [parallel, 1]) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = self._world_model.dynamics.get_feat(state)
if self._config.dropout_imag:
feat = tf.nn.dropout(feat, self._config.dropout_imag)
# action = self.actor(tf.stop_gradient(
# tf.split(feat, parallel, 0)[0])).sample()
# state = self._world_model.dynamics.img_step(
# state, tf.tile(action, [parallel, 1]))
action = self.actor(tf.stop_gradient(feat)).sample()
state = self._world_model.dynamics.img_step(state, action)
return state, feat, action
range_ = tf.range(self._config.imag_horizon)
feat = 0 * self._world_model.dynamics.get_feat(start)
action = 0 * self.actor(feat).mean()
_, feat, action = tools.static_scan(step, range_, (start, feat, action))
shape = (feat.shape[0], parallel, feat.shape[1] // parallel, feat.shape[2])
feat = tf.transpose(tf.reshape(feat, shape), [0, 2, 1, 3])
# feedback_feat = feat[:, :, :1] # Time, batch, 1, features.
# openloop_feat = feat[:, :, 1:] # Time, batch, mixture, features.
feedback_feat = []
openloop_feat = []
for i in range(parallel):
feedback_feat.append(feat[:, :, i, :])
openloop_feat.append(tf.concat([feat[:, :, :i, :], feat[:, :, i+1:, :]], axis=2))
feedback_feat = tf.stack(feedback_feat) # Parallel, Time, batch, features.
openloop_feat = tf.stack(openloop_feat) # Parallel, Time, batch, mixture, features.
reward = self._intrinsic_reward(feedback_feat, openloop_feat, action)
feedback_feat = tf.transpose(feedback_feat, [1,0,2,3])
feedback_feat = tf.reshape(feedback_feat, [feedback_feat.shape[0], -1, feedback_feat.shape[3]])
reward = tf.transpose(reward, [1,0,2])
reward = tf.reshape(reward, [reward.shape[0], -1])
if task_reward:
t_reward = task_reward(feedback_feat, 0, 0)
reward = dict(expl=reward, task=t_reward)
return feedback_feat, reward
def _intrinsic_reward(self, feedback_feat, openloop_feat, action):
fb_dist = self.actor(feedback_feat, dtype=tf.float32)
ol_dist = tools.uniform_mixture(
self.actor(openloop_feat, dtype=tf.float32), dtype=tf.float32)
if self._config.empow_entropy == 'sample':
S = self._config.empow_samples
cond = -tf.reduce_mean(fb_dist.log_prob(fb_dist.sample(S)), 0)
marg = -tf.reduce_mean(ol_dist.log_prob(ol_dist.sample(S)), 0)
elif self._config.empow_entropy == 'analytic':
cond = fb_dist.entropy()
if self._config.empow_sequences == 1:
marg = ol_dist.entropy()
elif ol_dist.dtype is tf.int32:
marg = tools.cat_mixture_entropy(ol_dist)
else:
raise NotImplementedError(self._config.empow_entropy)
assert marg.shape == cond.shape == feedback_feat.shape[0:3]
return self._config.empow_scale * (marg - cond)
class EmpowActionBehavior(tools.Module): # I(S; A), closed-loop, action space
def __init__(self, config, world_model, behavior):
self._config = config
self._world_model = world_model
self._behavior = behavior
self.actor = self._behavior.actor
def train(self, start, feat, embed, task_reward=None):
repeats = dict(
sequences=self._config.empow_sequences, batch=None, time=None,
)[self._config.empow_marginal]
metrics = self._behavior.train(
start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward=task_reward), repeats=repeats)
return metrics
def _intrinsic_reward(self, feat, state, action, task_reward=None):
if self._config.empow_marginal == 'sequences':
pass # Trajectory has trailing repeat dimension.
elif self._config.empow_marginal == 'batch':
pass # Trajectory has trailing batch dimension.
elif self._config.empow_marginal == 'time':
feat = tf.transpose(feat, [1, 0, 2])
state = {k: tf.transpose(v, [1, 0, 2]) for k, v in state.items()}
dist = self.actor(feat, dtype=tf.float32)
mixture = tools.uniform_mixture(dist, dtype=tf.float32)
if self._config.empow_entropy == 'sample':
S = self._config.empow_samples
cond = -tf.reduce_mean(dist.log_prob(dist.sample(S)), 0)
marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
elif self._config.empow_entropy == 'analytic':
assert dist.dtype is tf.int32
cond = dist.entropy()
marg = tools.cat_mixture_entropy(dist)
else:
raise NotImplementedError(self._config.empow_entropy)
marg = tf.repeat(marg[..., None], feat.shape[-2], -1)
assert marg.shape == cond.shape == feat.shape[:-1]
if self._config.empow_marginal == 'time':
marg = tf.transpose(marg, [1, 0])
cond = tf.transpose(cond, [1, 0])
reward = self._config.empow_scale * (marg - cond)
if task_reward:
t_reward = task_reward(feat, 0, 0)
reward = dict(expl=reward, task=t_reward)
return reward
class EmpowStateBehavior(tools.Module): # I(S; A), closed-loop, state space
def __init__(self, config, world_model, behavior):
self._config = config
self._world_model = world_model
self._behavior = behavior
self.actor = self._behavior.actor
def train(self, start, feat, embed, task_reward=None):
repeats = dict(
sequences=self._config.empow_sequences, batch=None, time=None,
)[self._config.empow_marginal]
metrics = self._behavior.train(
start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward), repeats=repeats)
return metrics
def _intrinsic_reward(self, feat, state, action, task_reward=None):
if self._config.empow_marginal == 'sequences':
pass # Trajectory has trailing repeat dimension.
elif self._config.empow_marginal == 'batch':
pass # Trajectory has trailing batch dimension.
elif self._config.empow_marginal == 'time':
feat = tf.transpose(feat, [1, 0, 2])
state = {k: tf.transpose(v, [1, 0, 2]) for k, v in state.items()}
dist = self._world_model.dynamics.get_dist(state, dtype=tf.float32)
mixture = tools.uniform_mixture(dist, dtype=tf.float32)
S = self._config.empow_samples
cond = -tf.reduce_mean(dist.log_prob(dist.sample(S)), 0)
marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
marg = tf.repeat(marg[..., None], feat.shape[-2], -1)
assert marg.shape == cond.shape == feat.shape[:-1]
if self._config.empow_marginal == 'time':
marg = tf.transpose(marg, [1, 0])
cond = tf.transpose(cond, [1, 0])
reward = self._config.empow_scale * (marg - cond)
if task_reward:
t_reward = task_reward(feat, 0, 0)
reward = dict(expl=reward, task=t_reward)
return reward
class EmpowStepBehavior(tools.Module): # I(S'; A), state space
def __init__(self, config, world_model, behavior):
self._config = config
self._world_model = world_model
self._behavior = behavior
self.actor = self._behavior.actor
def train(self, start, feat, embed, task_reward=None):
metrics = self._behavior.train(start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward))
return metrics
def _intrinsic_reward(self, feat, state, action, task_reward=None):
T, B, _ = feat.shape
K = self._config.empow_sequences
state = {k: tf.repeat(v, K, axis=1) for k, v in state.items()}
action = self.actor(self._world_model.dynamics.get_feat(state)).sample()
state = self._world_model.dynamics.img_step(state, action)
state = {k: tf.reshape(v, (T,B,K,v.shape[-1])) for k, v in state.items()}
dist = self._world_model.dynamics.get_dist(state, dtype=tf.float32)
mixture = tools.uniform_mixture(dist, dtype=tf.float32)
S = self._config.empow_samples
cond = -tf.reduce_mean(dist.log_prob(dist.sample(S)), (0, -1))
marg = -tf.reduce_mean(mixture.log_prob(mixture.sample(S)), 0)
reward = self._config.empow_scale * (marg - cond)
if task_reward:
t_reward = task_reward(feat, 0, 0)
reward = dict(expl=reward, task=t_reward)
return reward
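# Variational-information-maximization empowerment: learns inverse "planning"
# distributions p(a_0 | s, s') and p(a_t | s, a_{t-1}, s'), a source policy,
# and an offset head whose output, scaled by temperature * vim_horizon, is
# used as the intrinsic reward.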
class EmpowVIMBehavior(tools.Module): # I(S';A), action space
def __init__(self, config, world_model, behavior):
self._config = config
self._world_model = world_model
self._behavior = behavior
self.actor = self._behavior.actor
self._tau = config.temperature # temperature
self._plan_ss = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc) # p(a_0|s, s')
self._plan_sas = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc) # p(a_t|s, a_t-1, s')
self._source = networks.ActionHead(
config.num_actions, config.actor_layers, config.units, config.act,
config.actor_dist, config.actor_init_std, config.actor_disc)
self._offset = networks.DenseHead(
[], config.actor_layers, config.units, config.act)
self._plan_opt = tools.Adam(
'plan', config.actor_lr, config.opt_eps, config.grad_clip,
config.weight_decay)
self._source_opt = tools.Adam(
'source', config.actor_lr, config.opt_eps, config.grad_clip,
config.weight_decay) # for both source and offset
def train(self, start, feat, embed, task_reward=None):
metrics = {}
metrics.update(self._behavior.train(start, lambda f, s, a: self._intrinsic_reward(f, s, a, task_reward)))
metrics.update(self._train_vim(start))
return metrics
def _intrinsic_reward(self, feat, state, action, task_reward=None):
reward = self._tau * self._config.vim_horizon * self._offset(feat).mode()
if task_reward:
t_reward = task_reward(feat, 0, 0)
reward = dict(expl=reward, task=t_reward)
return reward
def _imagine(self, start): # Not used by self._behavior.train; rolls out the source policy instead of the actor.
flatten = lambda x: tf.reshape(x, [-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = self._world_model.dynamics.get_feat(state)
if self._config.dropout_imag:
feat = tf.nn.dropout(feat, self._config.dropout_imag)
action = self._source(tf.stop_gradient(feat)).sample() # source instead of actor
state = self._world_model.dynamics.img_step(state, action)
return state, feat, action
range_ = tf.range(self._config.imag_horizon)
feat = 0 * self._world_model.dynamics.get_feat(start)
action = 0 * self._source(feat).mean()
_, feat, action = tools.static_scan(step, range_, (start, feat, action))
return feat, action
def _train_vim(self, start):
feat, action = self._imagine(start)
h = self._config.vim_horizon
acts = []
for i in range(len(action)-h):
acts.append(action[i:i+h])
acts = tf.stack(acts) # time-horizon, horizon, batch, features
ss_feat = tf.concat([feat[:-h], feat[h:]], axis=-1) # time-horizon, batch, features
if h > 1:
sas_feat = tf.tile(tf.expand_dims(ss_feat, axis=1), [1, h-1, 1, 1])
sas_feat = tf.concat([sas_feat, acts[:, :-1]], axis=-1) # time-horizon, horizon-1, batch, features
action, acts = tf.stop_gradient(tf.cast(action, tf.float32)), tf.stop_gradient(tf.cast(acts, tf.float32))
metrics = {}
with tf.GradientTape() as tape:
plan_ss_pred = self._plan_ss(ss_feat, dtype=tf.float32)
log_plan_ss_prob = plan_ss_pred.log_prob(acts[:, 0])
loss = -tf.reduce_mean(log_plan_ss_prob)
if h > 1:
plan_sas_pred = self._plan_sas(sas_feat, dtype=tf.float32)
log_plan_sas_prob = tf.reduce_sum(plan_sas_pred.log_prob(acts[:, 1:]), axis=1)
loss += -tf.reduce_mean(log_plan_sas_prob)
loss *= 1./h
if h > 1:
norm = self._plan_opt(tape, loss, [self._plan_ss, self._plan_sas])
else:
norm = self._plan_opt(tape, loss, [self._plan_ss])
metrics['plan_grad_norm'] = norm
metrics['plan_loss'] = loss
if self._config.precision == 16:
metrics['plan_dynamics_loss_scale'] = self._plan_opt.loss_scale
with tf.GradientTape() as tape:
offset = self._offset(feat[:-h], dtype=tf.float32).mode()
action_pred = self._source(feat[:-h], dtype=tf.float32)
r_sa = action_pred.log_prob(action[:-h]) + offset
if h > 1:
log_plan_prob = log_plan_ss_prob + log_plan_sas_prob
else:
log_plan_prob = log_plan_ss_prob
loss = tf.reduce_mean(tf.square(1./self._tau * 1./h * tf.stop_gradient(log_plan_prob) - r_sa))
norm = self._source_opt(tape, loss, [self._source, self._offset])
metrics['source_grad_norm'] = norm
metrics['source_loss'] = loss
if self._config.precision == 16:
metrics['source_loss_scale'] = self._source_opt.loss_scale
return metrics