Skip to content

Instantly share code, notes, and snippets.

@hartikainen
Last active March 28, 2020 23:36
Show Gist options
  • Save hartikainen/6be9ec29c47b2c28df1a7e21ae6e73fe to your computer and use it in GitHub Desktop.
Save hartikainen/6be9ec29c47b2c28df1a7e21ae6e73fe to your computer and use it in GitHub Desktop.
Policy evaluation with bilevel optimization
import argparse
import datetime
import itertools
import os
import numpy as np
import ray
from ray import tune
import tensorflow as tf
import tree
from .visualization import visualize, visualize_experiment
class TsitsiklisTriangle:
    """Tsitsiklis Triangle counterexample MDP for TD divergence, see [1].

    Three states, a single action, and zero reward on every transition.
    From state ``s`` the chain either stays at ``s`` or moves to
    ``(s - 1) % 3``, each with probability 0.5. Because every reward is
    zero, the true value function is identically zero (``optimal_V``).

    References:
        [1] Tsitsiklis, John N., and Benjamin Van Roy.
            "Analysis of temporal-difference learning with function
            approximation." Advances in neural information processing
            systems. 1997.
    """

    def __init__(self):
        """Build the MDP tables and enumerate all possible transitions."""
        n_states = 3
        self.states = np.arange(n_states, dtype=np.int32)
        self.actions = np.zeros(1, dtype=np.int32)
        # Uniform initial-state distribution (not used by ``sample``).
        self.s_0_probabilities = np.ones(n_states) / n_states
        # Both tables are indexed as [state_0, action, state_1].
        self.rewards = np.zeros((n_states, 1, n_states))
        self.transition_probabilities = np.array([
            [[0.5, 0.0, 0.5]],  # 0 -> {0, 2}
            [[0.5, 0.5, 0.0]],  # 1 -> {0, 1}
            [[0.0, 0.5, 0.5]],  # 2 -> {1, 2}
        ])
        # All rewards are zero, hence the true values are all zero.
        self.optimal_V = np.zeros(n_states)
        self.all_transitions = self._generate_all_transitions()

    def _generate_all_transitions(self):
        """Return every nonzero-probability (s, a, s', r) as parallel arrays."""
        reachable = self.transition_probabilities[self.states] > 0
        from_state, action, to_state = np.nonzero(reachable)
        return {
            'state_0': from_state,
            'action': action,
            'state_1': to_state,
            'reward': self.rewards[from_state, action, to_state],
        }

    def sample(self):
        """Draw one transition uniformly at random from ``all_transitions``."""
        transitions = self.all_transitions
        index = np.random.randint(0, len(transitions['state_0']))
        return {key: column[index] for key, column in transitions.items()}
class ValueFunction(tf.keras.Model):
    """Value function over the three triangle states driven by one weight.

    With scalar weight ``w`` and ``x = sqrt(3) * w / 2``, the per-state
    values are ``exp(epsilon * w) * (sqrt(3) sin x - cos x,
    -sqrt(3) sin x - cos x, 2 cos x)``.
    """

    def __init__(self, initial_weight=0.0, epsilon=5e-2):
        super(ValueFunction, self).__init__()
        # Kept so that ``get_config`` can rebuild an identically set-up model.
        self._initial_weight = initial_weight
        self._epsilon = epsilon
        self.weight = self.add_weight(
            'weight',
            shape=[1],
            initializer=tf.initializers.constant(initial_weight))

    def call(self, inputs):
        """Gather V(s) for a (batch, 1) tensor of state indices."""
        tf.debugging.assert_shapes([[inputs, (None, 1)]])
        w = self.weight
        sqrt_3 = tf.sqrt(3.0)
        x = (sqrt_3 * w) / 2.0
        sin_x = tf.sin(x)
        cos_x = tf.cos(x)
        growth = tf.exp(self._epsilon * w)
        # Each entry has the weight's shape [1], so the stack is (3, 1).
        per_state = tf.stack((
            sqrt_3 * sin_x - cos_x,
            -sqrt_3 * sin_x - cos_x,
            2.0 * cos_x,
        ))
        V = growth * per_state

        values = tf.gather_nd(V, inputs)
        # Output must mirror the (batch, 1) layout of ``inputs``.
        tf.debugging.assert_equal(tf.rank(values), 2)
        tf.debugging.assert_equal(tf.shape(values), tf.shape(inputs))
        tf.debugging.assert_equal(tf.shape(values)[1], 1)
        return values

    def get_config(self):
        """Constructor kwargs needed to recreate this model."""
        return {
            'initial_weight': self._initial_weight,
            'epsilon': self._epsilon,
        }
class TD0:
    """Semi-gradient TD(0) policy evaluation for a scalar-weight value model.

    Args:
        V: model mapping a (batch, 1) array of state indices to values; it
            must expose a ``weight`` variable that is updated in place.
        alpha: step size.
        gamma: discount factor.
    """

    def __init__(self, V, alpha=2e-3, gamma=0.9):
        self.V = V
        self._alpha = alpha
        self._gamma = gamma

    def update_V(self, state_0, action, state_1, reward):
        """Apply one TD(0) step moving V(state_0) toward r + gamma * V(state_1).

        ``action`` is unused; it is accepted so a transition dict can be
        splatted in directly.
        """
        # The bootstrapped target is computed outside the tape: semi-gradient
        # TD does not differentiate through V(state_1).
        V_s1 = self.V(np.atleast_2d(state_1))[0]
        target = reward + self._gamma * V_s1
        with tf.GradientTape() as tape:
            V_s0 = self.V(np.atleast_2d(state_0))[0]
        grad_V = tape.gradient(V_s0, self.V.weight)
        # delta is prediction minus target, so the step must *subtract*
        # alpha * delta * grad to reduce the TD error. Adding it, as before,
        # performed gradient ascent on the squared TD error.
        delta = V_s0 - target
        omega_0 = self.V.weight.value()
        omega_1 = omega_0 - self._alpha * delta * grad_V
        self.V.weight.assign(omega_1)
class Bilevel:
    """Two-timescale policy evaluation with a fitted and an averaged model.

    ``V_theta`` takes gradient steps (size ``beta``) toward targets
    bootstrapped from ``V_omega``. ``V_omega`` then moves toward the
    pre-update ``V_theta`` weight by exponential averaging (rate ``alpha``).
    ``V_omega`` is the model exposed as ``V``.
    """

    def __init__(self, V_omega, alpha=5e-2, beta=5e-1, gamma=0.9):
        self.V_omega = V_omega
        # Fresh twin of the same class, built from V_omega's constructor config.
        model_cls = type(V_omega)
        self.V_theta = model_cls(**V_omega.get_config())
        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma

    @property
    def V(self):
        """The value model reported for evaluation (``V_omega``)."""
        return self.V_omega

    def update_V(self, state_0, action, state_1, reward):
        """Run one inner (theta) step followed by one outer (omega) averaging step.

        ``action`` is unused; it is accepted so a transition dict can be
        splatted in directly.
        """
        theta_old = self.V_theta.weight.value()
        omega_old = self.V_omega.weight.value()

        # Inner step: regress V_theta(s0) toward r + gamma * V_omega(s1).
        bootstrap = self.V_omega(np.atleast_2d(state_1))[0]
        target = reward + self._gamma * bootstrap
        with tf.GradientTape() as tape:
            prediction = self.V_theta(np.atleast_2d(state_0))[0]
        grad_theta = tape.gradient(prediction, self.V_theta.weight)
        error = prediction - target
        self.V_theta.weight.assign(theta_old - self._beta * error * grad_theta)

        # Outer step: average omega toward theta as read before the inner step.
        self.V_omega.weight.assign(
            (1 - self._alpha) * omega_old + self._alpha * theta_old)
def verify_spiral_sample(state_0, action, state_1, reward):
    """Assert that a sampled transition is consistent with the triangle MDP.

    Raises AssertionError if either state is outside [0, 3), if the move is
    the forbidden ``s -> (s + 1) % 3`` step, or if the reward is nonzero.
    ``action`` is accepted but not checked.
    """
    # Both endpoints must be one of the three states.
    assert 0 <= state_0 < 3, state_0
    assert 0 <= state_1 < 3, state_1
    # The chain never steps "forward" by one state.
    forbidden_next_state = (state_0 + 1) % 3
    assert state_1 != forbidden_next_state
    # Every transition is unrewarded.
    assert reward == 0, reward
class ExperimentRunner(tune.Trainable):
    """Ray Tune trainable that runs one evaluation method on the triangle MDP.

    Config keys read here: ``'method'`` (a dict with ``'class'`` and
    ``'config'``), ``'epoch_length'`` (updates per reported result) and
    ``'num_steps'`` (total update budget).
    """

    def _setup(self, variant):
        """Build the method around a fresh ``ValueFunction`` plus the environment."""
        # In local mode all trials share the driver process. Switch into this
        # trial's log directory so relative output paths land with its logs.
        # Note that ``os.chdir(os.getcwd())`` would be a no-op.
        if ray.worker._mode() == ray.worker.LOCAL_MODE:
            os.chdir(self.logdir)
        # Reseed from OS entropy. Seeding with a draw from the same global RNG
        # limited each trial to one of 10000 seeds, so repeated samples of a
        # variant could end up with identical transition streams.
        np.random.seed()

        method_variant = variant['method']
        method_class = method_variant['class']
        method_config = method_variant['config']
        self.method = method_class(ValueFunction(), **method_config)
        self.environment = TsitsiklisTriangle()

    def _train(self):
        """Run ``epoch_length`` updates and report MSE against the true values."""
        epoch_length = self.config['epoch_length']
        for _ in range(epoch_length):
            sample = self.environment.sample()
            verify_spiral_sample(**sample)  # Just a sanity check
            self.method.update_V(**sample)

        current_V = self.method.V(self.environment.states[..., None])
        MSE = tf.losses.MSE(
            y_pred=current_V[:, 0], y_true=self.environment.optimal_V)
        # Total updates so far, including the epoch that just finished.
        training_steps = (self.iteration + 1) * epoch_length
        return {
            'MSE': MSE.numpy(),
            'weight': self.method.V.weight.numpy().item(),
            'training_steps': training_steps,
            # Stop as soon as the budget is reached. The previous strict
            # ``num_steps < training_steps`` ran one extra epoch past it.
            'done': training_steps >= self.config['num_steps'],
        }
def train(num_samples,
          num_steps,  # Total training steps
          epoch_length):  # Log interval
    """Grid-search TD0 and Bilevel step sizes with Ray Tune, then plot.

    Every TD0 ``alpha`` and every Bilevel ``(alpha, beta)`` pair drawn from
    the learning-rate grid is run ``num_samples`` times. Results are logged
    under ``./data/<timestamp>`` and handed to ``visualize``.
    """
    LEARNING_RATES = (1e-3, 1e-2, 1e-1, 1.0)  # Learning rates to sweep over
    GAMMA = 0.9  # Discount

    def method_variant(method_class, **method_config):
        # ``class_name`` is resolved by Tune per trial from the chosen class.
        return {
            'class': method_class,
            'class_name': tune.sample_from(
                lambda spec: spec['config']['method']['class'].__name__),
            'config': method_config,
        }

    td0_variants = [
        method_variant(TD0, alpha=alpha, gamma=GAMMA)
        for alpha in LEARNING_RATES
    ]
    # All (alpha, beta) pairs. Filter on ``alpha < beta`` to sweep only the
    # regime where theta (step size beta) moves faster than omega (rate alpha).
    bilevel_variants = [
        method_variant(Bilevel, alpha=alpha, beta=beta, gamma=GAMMA)
        for alpha, beta in itertools.product(LEARNING_RATES, repeat=2)
    ]
    METHOD_VARIANTS = td0_variants + bilevel_variants

    timestamp_format = '%Y{d}%m{d}%dT%H{d}%M{d}%S'.format(d="-")
    datetime_stamp = datetime.datetime.now().strftime(timestamp_format)

    # Set this to `local_mode=True` if you want to debug with breakpoints.
    ray.init(local_mode=False)
    result = tune.run(
        ExperimentRunner,
        name=datetime_stamp,
        config={
            'epoch_length': epoch_length,
            'num_steps': num_steps,
            'method': tune.grid_search(METHOD_VARIANTS),
        },
        num_samples=num_samples,
        local_dir=os.path.join(os.getcwd(), 'data'),
        checkpoint_freq=0,
        checkpoint_at_end=False,
        max_failures=0,
        with_server=False,
        loggers=(ray.tune.logger.CSVLogger, ray.tune.logger.JsonLogger),
    )
    visualize(result)
if __name__ == '__main__':
    # Command-line entry point: run the sweep, or plot a saved experiment.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, choices=('train', 'visualize'), default='visualize')
    parser.add_argument('--num-samples', type=int, default=25)
    parser.add_argument('--num-steps', type=int, default=1000)
    parser.add_argument('--epoch-length', type=int, default=25)
    parser.add_argument('--experiment-path', type=str, default=None)
    args = parser.parse_args()

    # Visualization needs a previously written experiment directory.
    if args.mode == 'visualize' and args.experiment_path is None:
        raise ValueError("Set '--experiment-path [path-to-experiment]'.")

    if args.mode == 'train':
        train(
            num_samples=args.num_samples,
            num_steps=args.num_steps,
            epoch_length=args.epoch_length,
        )
    elif args.mode == 'visualize':
        visualize_experiment(args.experiment_path)
dm-tree==0.1.2
matplotlib==3.2.1
pandas==1.0.1
ray[tune]==0.8.3
seaborn==0.10.0
tensorflow==2.1.0
import numpy as np
import tree
from .main import TsitsiklisTriangle
# Sanity checks for TsitsiklisTriangle.sample(): draw many transitions and
# check that the empirical state counts and transition set are as expected.
NUM_SAMPLES = 1000
# The 6 transitions are sampled uniformly and each state appears in exactly two
# of them as state_0 and as state_1, so every per-state count is
# Binomial(NUM_SAMPLES, 1/3): mean ~333, std ~15. The previous bound of 300 was
# only ~2.2 std below the mean and made the checks fail a few percent of the
# time. A bound of 250 is ~5.6 std below the mean.
MIN_COUNT_PER_STATE = 250

tsitsiklis_triangle = TsitsiklisTriangle()
all_transitions = tsitsiklis_triangle._generate_all_transitions()
samples = tree.map_structure(
    lambda *x: np.stack(x),
    *[tsitsiklis_triangle.sample() for _ in range(NUM_SAMPLES)])

states_0, counts_0 = np.unique(samples['state_0'], return_counts=True)
np.testing.assert_equal(states_0, (0, 1, 2))
assert all(MIN_COUNT_PER_STATE < counts_0), counts_0

states_1, counts_1 = np.unique(samples['state_1'], return_counts=True)
np.testing.assert_equal(states_1, (0, 1, 2))
assert all(MIN_COUNT_PER_STATE < counts_1), counts_1

# Every sampled (state_0, state_1) pair must be a valid transition. With this
# many samples, every valid transition should also show up at least once.
valid_pairs = set(zip(all_transitions['state_0'].tolist(),
                      all_transitions['state_1'].tolist()))
sampled_pairs = set(zip(samples['state_0'].tolist(),
                        samples['state_1'].tolist()))
assert sampled_pairs == valid_pairs, (sampled_pairs, valid_pairs)

assert np.all(samples['state_1'] != ((samples['state_0'] + 1) % 3))
np.testing.assert_equal(samples['action'], 0)
np.testing.assert_equal(samples['reward'], 0)
import tensorflow as tf
class ValueFunctionV2(tf.keras.Model):
    """Weight-driven value function with fixed per-state coefficients.

    V(w) = exp(epsilon * w) * (A * cos(lambda * w) - B * sin(lambda * w)),
    where lambda = sqrt(3) / 2 and A, B are constant (3, 1) columns (one row
    per state).
    """

    def __init__(self, initial_weight=0.0, epsilon=5e-2):
        super(ValueFunctionV2, self).__init__()
        # Kept so that ``get_config`` can rebuild an identically set-up model.
        self._initial_weight = initial_weight
        self._epsilon = epsilon
        self.weight = self.add_weight(
            'weight',
            shape=[1],
            initializer=tf.initializers.constant(initial_weight))

    def call(self, inputs):
        """Gather V(s) for a (batch, 1) tensor of state indices."""
        tf.debugging.assert_shapes([[inputs, (None, 1)]])
        w = self.weight
        angle = (tf.sqrt(3.0) / 2.0) * w
        cos_coefficients = tf.constant([[100.0], [-70.0], [-30.0]])
        sin_coefficients = tf.constant([[23.094], [-98.15], [75.056]])
        growth = tf.exp(self._epsilon * w)
        V = growth * (
            cos_coefficients * tf.cos(angle)
            - sin_coefficients * tf.sin(angle))

        values = tf.gather_nd(V, inputs)
        # Output must mirror the (batch, 1) layout of ``inputs``.
        tf.debugging.assert_equal(tf.rank(values), 2)
        tf.debugging.assert_equal(tf.shape(values), tf.shape(inputs))
        tf.debugging.assert_equal(tf.shape(values)[1], 1)
        return values

    def get_config(self):
        """Constructor kwargs needed to recreate this model."""
        return {
            'initial_weight': self._initial_weight,
            'epsilon': self._epsilon,
        }
import os
import numpy as np
from ray import tune
import tree
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def visualize(result):
    """Plot per-trial MSE and weight curves and save them as PDF and PNG.

    Args:
        result: Tune analysis object. Its ``fetch_trial_dataframes()`` and
            ``get_all_configs()`` must be keyed identically, and the figures
            are written to ``result._experiment_dir`` as ``result.pdf`` and
            ``result.png``.
    """
    trial_dataframes = result.fetch_trial_dataframes()
    trial_configs = result.get_all_configs()

    def create_visualization_dataframe(dataframe, config):
        # Builds a legend label such as
        # "Bilevel: $\alpha$=1e-02, $\beta$=1e-01, $\gamma$=9e-01".
        method = config['method']
        label = (
            f"{method['class_name']}: "
            + ", ".join(f"{key}={value:.0e}"
                        for key, value
                        in sorted(method['config'].items())))
        for greek_letter in ("alpha", "beta", "gamma"):
            # Render as LaTeX math. The backslash is escaped explicitly
            # because "\{" is an invalid escape sequence (SyntaxWarning on
            # Python 3.12+).
            label = label.replace(greek_letter, f"$\\{greek_letter}$")
        dataframe['label'] = label
        return dataframe[['MSE', 'weight', 'label', 'training_steps']]

    visualization_dataframes = tree.map_structure_up_to(
        {x: None for x in trial_dataframes},
        create_visualization_dataframe, trial_dataframes, trial_configs)
    visualization_dataframe = pd.concat(visualization_dataframes.values())

    # Two side-by-side panels: twice the default width, square-ish height.
    default_figsize = plt.rcParams.get('figure.figsize')
    figsize = np.array((2, 1)) * default_figsize[0]
    figure, axes = plt.subplots(1, 2, figsize=figsize)
    try:
        sns.lineplot(
            x='training_steps',
            y='MSE',
            hue='label',
            data=visualization_dataframe,
            legend=False,
            ax=axes[0])
        sns.lineplot(
            x='training_steps',
            y='weight',
            hue='label',
            data=visualization_dataframe,
            legend='brief',
            ax=axes[1])

        # Keep diverging MSE curves from dominating the left panel.
        axes[0].set_ylim(np.clip(axes[0].get_ylim(), -0.01, 5.0))
        axes[1].set_yscale('symlog')
        legend = axes[1].legend(
            loc='center left',
            bbox_to_anchor=(1.05, 0.5),
            ncol=1,
            borderaxespad=0.0)

        for extension in ('pdf', 'png'):
            figure.savefig(
                os.path.join(result._experiment_dir, f'result.{extension}'),
                bbox_extra_artists=(legend, ),
                bbox_inches='tight')
    finally:
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(figure)
def visualize_experiment(experiment_path):
    """Load the Tune experiment stored at ``experiment_path`` and plot it."""
    visualize(tune.Analysis(experiment_path))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment