Policy evaluation with bilevel optimization
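In short: the script below compares plain TD(0) with a bilevel (two-timescale) policy-evaluation update on the Tsitsiklis and Van Roy "spiral" counterexample. Transcribing the TD0 and Bilevel classes below (alpha and beta are step sizes, gamma the discount, theta the fast inner variable, omega the slow outer one; the code applies the omega update with the pre-update theta):

TD(0):    \delta_t = r_t + \gamma V_\omega(s_{t+1}) - V_\omega(s_t), \qquad \omega \leftarrow \omega + \alpha\,\delta_t\,\nabla_\omega V_\omega(s_t)

Bilevel:  \delta_t = r_t + \gamma V_\omega(s_{t+1}) - V_\theta(s_t), \qquad \theta \leftarrow \theta + \beta\,\delta_t\,\nabla_\theta V_\theta(s_t), \qquad \omega \leftarrow (1 - \alpha)\,\omega + \alpha\,\theta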
main.py
import argparse
import datetime
import itertools
import os

import numpy as np
import ray
from ray import tune
import tensorflow as tf
import tree

from .visualization import visualize, visualize_experiment
class TsitsiklisTriangle:
    """Tsitsiklis triangle counterexample for divergence of TD(0) with
    nonlinear value-function approximation.

    For details, see [1].

    References:
        [1] Tsitsiklis, John N., and Benjamin Van Roy.
        "Analysis of temporal-difference learning with function approximation."
        Advances in Neural Information Processing Systems. 1997.
    """

    def __init__(self):
        """Tsitsiklis triangle counterexample MDP."""
        self.states = np.arange(3, dtype=np.int32)
        self.actions = np.zeros(1, dtype=np.int32)
        self.s_0_probabilities = np.ones(3) / 3
        self.rewards = np.zeros((3, 1, 3))
        self.transition_probabilities = np.array((
            ((0.5, 0.0, 0.5), ),
            ((0.5, 0.5, 0.0), ),
            ((0.0, 0.5, 0.5), ),
        ))
        self.optimal_V = np.zeros(3)
        self.all_transitions = self._generate_all_transitions()

    def _generate_all_transitions(self):
        states_0, actions, states_1 = np.where(
            0 < self.transition_probabilities[self.states])
        rewards = self.rewards[states_0, actions, states_1]
        transitions = {
            'state_0': states_0,
            'action': actions,
            'state_1': states_1,
            'reward': rewards,
        }
        return transitions

    def sample(self):
        num_transitions = tree.flatten(self.all_transitions)[0].shape[0]
        random_index = np.random.randint(0, num_transitions)
        sample = tree.map_structure(
            lambda x: x[random_index], self.all_transitions)
        return sample
class ValueFunction(tf.keras.Model):
    def __init__(self, initial_weight=0.0, epsilon=5e-2):
        super(ValueFunction, self).__init__()
        self._initial_weight = initial_weight
        self._epsilon = epsilon
        self.weight = self.add_weight(
            'weight',
            shape=[1],
            initializer=tf.initializers.constant(initial_weight))

    def call(self, inputs):
        tf.debugging.assert_shapes([[inputs, (None, 1)]])
        # Spiral parameterization from the counterexample: a rotating 3-vector
        # scaled by exp(epsilon * w), so V_w is nonlinear in the scalar weight w.
        x = (tf.sqrt(3.0) * self.weight) / 2.0
        V = tf.exp(self._epsilon * self.weight) * tf.stack((
            tf.sqrt(3.0) * tf.sin(x) - tf.cos(x),
            - tf.sqrt(3.0) * tf.sin(x) - tf.cos(x),
            2.0 * tf.cos(x),
        ))
        result = tf.gather_nd(V, inputs)
        tf.debugging.assert_equal(tf.rank(result), 2)
        tf.debugging.assert_equal(tf.shape(result), tf.shape(inputs))
        tf.debugging.assert_equal(tf.shape(result)[1], 1)
        return result

    def get_config(self):
        config = {
            'initial_weight': self._initial_weight,
            'epsilon': self._epsilon,
        }
        return config
class TD0:
    def __init__(self, V, alpha=2e-3, gamma=0.9):
        self.V = V
        self._alpha = alpha
        self._gamma = gamma

    def update_V(self, state_0, action, state_1, reward):
        V_s1 = self.V(np.atleast_2d(state_1))[0]
        target = reward + self._gamma * V_s1
        with tf.GradientTape() as tape:
            V_s0 = self.V(np.atleast_2d(state_0))[0]
        grad_V = tape.gradient(V_s0, self.V.weight)
        delta = V_s0 - target
        omega_0 = self.V.weight.value()
        # Semi-gradient TD(0): omega <- omega + alpha * (target - V(s_0)) * grad V(s_0),
        # i.e. step against delta = V(s_0) - target.
        omega_1 = omega_0 - self._alpha * delta * grad_V
        self.V.weight.assign(omega_1)
class Bilevel:
    def __init__(self, V_omega, alpha=5e-2, beta=5e-1, gamma=0.9):
        self.V_omega = V_omega
        self.V_theta = type(V_omega)(**V_omega.get_config())
        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma

    @property
    def V(self):
        return self.V_omega

    def update_V(self, state_0, action, state_1, reward):
        theta_0 = self.V_theta.weight.value()
        omega_0 = self.V_omega.weight.value()
        # Inner step: gradient step on the TD error of V_theta against a
        # target computed with the slow variable V_omega.
        V_omega_s1 = self.V_omega(np.atleast_2d(state_1))[0]
        target = reward + self._gamma * V_omega_s1
        with tf.GradientTape() as tape:
            V_theta_s0 = self.V_theta(np.atleast_2d(state_0))[0]
        grad_V_theta = tape.gradient(V_theta_s0, self.V_theta.weight)
        delta = V_theta_s0 - target
        theta_1 = theta_0 - self._beta * delta * grad_V_theta
        self.V_theta.weight.assign(theta_1)
        # Outer step: move the slow variable omega towards (the pre-update) theta.
        omega_1 = (1 - self._alpha) * omega_0 + self._alpha * theta_0
        self.V_omega.weight.assign(omega_1)
def verify_spiral_sample(state_0, action, state_1, reward):
    assert 0 <= state_0 < 3, state_0
    assert 0 <= state_1 < 3, state_1
    assert state_1 != ((state_0 + 1) % 3)
    assert reward == 0, reward
class ExperimentRunner(tune.Trainable):
    def _setup(self, variant):
        # Set the current working directory so that local-mode runs log into
        # the correct place. This is not needed in non-local (cluster) mode.
        if ray.worker._mode() == ray.worker.LOCAL_MODE:
            os.chdir(os.getcwd())

        np.random.seed(np.random.randint(0, 10000))

        method_variant = variant['method']
        method_class = method_variant['class']
        method_config = method_variant['config']
        self.method = method_class(ValueFunction(), **method_config)
        self.environment = TsitsiklisTriangle()

    def _train(self):
        for step in range(self.config['epoch_length']):
            sample = self.environment.sample()
            verify_spiral_sample(**sample)  # Just a sanity check
            self.method.update_V(**sample)

        current_V = self.method.V(self.environment.states[..., None])
        MSE = tf.losses.MSE(
            y_pred=current_V[:, 0], y_true=self.environment.optimal_V)

        result = {
            'MSE': MSE.numpy(),
            'weight': self.method.V.weight.numpy().item(),
            'training_steps': (
                (self.iteration + 1) * self.config['epoch_length'])
        }
        result['done'] = self.config['num_steps'] < result['training_steps']

        return result
def train(num_samples,
          num_steps,      # Total training steps
          epoch_length):  # Log interval
    LEARNING_RATES = (1e-3, 1e-2, 1e-1, 1.0)  # Learning rates to sweep over
    GAMMA = 0.9  # Discount

    METHOD_VARIANTS = [
        {
            'class': TD0,
            'class_name': tune.sample_from(
                lambda spec: spec['config']['method']['class'].__name__),
            'config': {
                'alpha': alpha,
                'gamma': GAMMA,
            },
        } for alpha in LEARNING_RATES
    ] + [
        {
            'class': Bilevel,
            'class_name': tune.sample_from(
                lambda spec: spec['config']['method']['class'].__name__),
            'config': {
                'alpha': alpha,
                'beta': beta,
                'gamma': GAMMA,
            },
        }
        for alpha, beta in itertools.product(LEARNING_RATES, repeat=2)
        # if alpha < beta
    ]

    datetime_stamp = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')

    # Set this to `local_mode=True` if you want to debug with breakpoints.
    ray.init(local_mode=False)
    result = tune.run(
        ExperimentRunner,
        name=datetime_stamp,
        config={
            'epoch_length': epoch_length,
            'num_steps': num_steps,
            'method': tune.grid_search(METHOD_VARIANTS),
        },
        num_samples=num_samples,
        local_dir=os.path.join(os.getcwd(), 'data'),
        checkpoint_freq=0,
        checkpoint_at_end=False,
        max_failures=0,
        with_server=False,
        loggers=(ray.tune.logger.CSVLogger, ray.tune.logger.JsonLogger),
    )

    visualize(result)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--mode',
        type=str,
        choices=('train', 'visualize'),
        default='visualize')
    parser.add_argument('--num-samples', type=int, default=25)
    parser.add_argument('--num-steps', type=int, default=1000)
    parser.add_argument('--epoch-length', type=int, default=25)
    parser.add_argument('--experiment-path', type=str, default=None)
    args = parser.parse_args()

    if args.mode == 'train':
        train(num_samples=args.num_samples,
              num_steps=args.num_steps,
              epoch_length=args.epoch_length)
    elif args.mode == 'visualize':
        if args.experiment_path is None:
            raise ValueError("Set '--experiment-path [path-to-experiment]'.")
        visualize_experiment(args.experiment_path)
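To poke at the two update rules without Ray/tune, a plain loop like the following works. This is only a sketch: the `from main import ...` line assumes the file above is saved as main.py and is importable, and the step sizes and step counts are illustrative rather than the swept values.

import tensorflow as tf

from main import Bilevel, TD0, TsitsiklisTriangle, ValueFunction

environment = TsitsiklisTriangle()
methods = {
    'TD0': TD0(ValueFunction(), alpha=1e-2, gamma=0.9),
    'Bilevel': Bilevel(ValueFunction(), alpha=1e-2, beta=1e-1, gamma=0.9),
}

for step in range(1, 1001):
    sample = environment.sample()
    for method in methods.values():
        method.update_V(**sample)
    if step % 250 == 0:
        for name, method in methods.items():
            # Evaluate the current value estimates on all three states and
            # compare against the true values (all zeros).
            V = method.V(environment.states[..., None])
            MSE = tf.losses.MSE(y_pred=V[:, 0], y_true=environment.optimal_V)
            print(f"step={step:4d} {name}: "
                  f"weight={method.V.weight.numpy().item():+.3f} "
                  f"MSE={MSE.numpy():.3f}")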
requirements.txt
dm-tree==0.1.2
matplotlib==3.2.1
pandas==1.0.1
ray[tune]==0.8.3
seaborn==0.10.0
tensorflow==2.1.0
Sampling sanity checks for TsitsiklisTriangle:
import numpy as np
import tree

from .main import TsitsiklisTriangle

tsitsiklis_triangle = TsitsiklisTriangle()
all_transitions = tsitsiklis_triangle._generate_all_transitions()

# Draw a batch of transitions and stack them field-wise.
samples = tree.map_structure(
    lambda *x: np.stack(x),
    *[tsitsiklis_triangle.sample() for i in range(1000)])

# Every state should appear as both a source and a destination roughly
# uniformly often.
states_0, counts_0 = np.unique(samples['state_0'], return_counts=True)
np.testing.assert_equal(states_0, (0, 1, 2))
assert all(300 < counts_0)

states_1, counts_1 = np.unique(samples['state_1'], return_counts=True)
np.testing.assert_equal(states_1, (0, 1, 2))
assert all(300 < counts_1)

states_0_1 = np.concatenate((
    samples['state_0'][..., None],
    samples['state_1'][..., None]
), axis=1)

# Transitions only stay in place or move "backwards" around the triangle,
# there is a single dummy action, and all rewards are zero.
assert np.all(samples['state_1'] != ((samples['state_0'] + 1) % 3))
np.testing.assert_equal(samples['action'], 0)
np.testing.assert_equal(samples['reward'], 0)
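An additional check along the same lines (a sketch, not part of the original snippet; it assumes the variables above are still in scope): the empirical transition matrix estimated from the same samples should be close to the 0.5/0.5 rows of transition_probabilities, since sampling is uniform over the six possible transitions.

empirical = np.zeros((3, 3))
for s0, s1 in zip(samples['state_0'], samples['state_1']):
    empirical[s0, s1] += 1
empirical /= empirical.sum(axis=1, keepdims=True)

# transition_probabilities has shape (3, 1, 3); drop the singleton action axis.
np.testing.assert_allclose(
    empirical,
    tsitsiklis_triangle.transition_probabilities[:, 0, :],
    atol=0.1)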
Alternative spiral parameterization (ValueFunctionV2):
import tensorflow as tf


class ValueFunctionV2(tf.keras.Model):
    def __init__(self, initial_weight=0.0, epsilon=5e-2):
        super(ValueFunctionV2, self).__init__()
        self._initial_weight = initial_weight
        self._epsilon = epsilon
        self.weight = self.add_weight(
            'weight',
            shape=[1],
            initializer=tf.initializers.constant(initial_weight))

    def call(self, inputs):
        tf.debugging.assert_shapes([[inputs, (None, 1)]])
        # Same spiral construction as ValueFunction, written with an explicit
        # basis: V_w = exp(epsilon * w) * (A cos(lambda w) - B sin(lambda w)).
        lambda_ = tf.sqrt(3.0) / 2.0
        A = tf.constant(([100.0], [-70.0], [-30.0]))
        B = tf.constant(([23.094], [-98.15], [75.056]))
        V = tf.exp(self._epsilon * self.weight) * (
            A * tf.cos(lambda_ * self.weight)
            - B * tf.sin(lambda_ * self.weight))
        result = tf.gather_nd(V, inputs)
        tf.debugging.assert_equal(tf.rank(result), 2)
        tf.debugging.assert_equal(tf.shape(result), tf.shape(inputs))
        tf.debugging.assert_equal(tf.shape(result)[1], 1)
        return result

    def get_config(self):
        config = {
            'initial_weight': self._initial_weight,
            'epsilon': self._epsilon,
        }
        return config
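A quick shape check (a sketch, assuming ValueFunctionV2 above is already in scope or importable): the class exposes the same call/get_config interface as ValueFunction, so it can be handed to TD0 or Bilevel unchanged.

import numpy as np

value_function = ValueFunctionV2(initial_weight=0.0, epsilon=5e-2)
states = np.arange(3, dtype=np.int32)[..., None]  # shape (3, 1), one row per state
values = value_function(states)                   # shape (3, 1); [[100.], [-70.], [-30.]] at w = 0
print(values.numpy())
print(value_function.get_config())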
visualization.py
import os

import numpy as np
from ray import tune
import tree
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def visualize(result):
    # sns.set_style(style='dark')
    trial_dataframes = result.fetch_trial_dataframes()
    trial_configs = result.get_all_configs()

    def create_visualization_dataframe(dataframe, config):
        label = (
            f"{config['method']['class_name']}: "
            + ", ".join(f"{key}={value:.0e}"
                        for key, value
                        in sorted(config['method']['config'].items())))
        for greek_alphabet in ("alpha", "beta", "gamma"):
            label = label.replace(greek_alphabet, f"$\\{greek_alphabet}$")
        dataframe['label'] = label
        return dataframe[['MSE', 'weight', 'label', 'training_steps']]

    visualization_dataframes = tree.map_structure_up_to(
        {x: None for x in trial_dataframes},
        create_visualization_dataframe, trial_dataframes, trial_configs)
    visualization_dataframe = pd.concat(visualization_dataframes.values())

    default_figsize = plt.rcParams.get('figure.figsize')
    figsize = np.array((2, 1)) * np.max(default_figsize[0])
    figure, axes = plt.subplots(1, 2, figsize=figsize)

    sns.lineplot(
        x='training_steps',
        y='MSE',
        hue='label',
        data=visualization_dataframe,
        legend=False,
        ax=axes[0])
    sns.lineplot(
        x='training_steps',
        y='weight',
        hue='label',
        data=visualization_dataframe,
        legend='brief',
        ax=axes[1])

    axes[0].set_ylim(np.clip(axes[0].get_ylim(), -0.01, 5.0))
    axes[1].set_yscale('symlog')

    legend = axes[1].legend(
        loc='center left',
        bbox_to_anchor=(1.05, 0.5),
        ncol=1,
        borderaxespad=0.0)

    plt.savefig(
        os.path.join(result._experiment_dir, 'result.pdf'),
        bbox_extra_artists=(legend, ),
        bbox_inches='tight')
    plt.savefig(
        os.path.join(result._experiment_dir, 'result.png'),
        bbox_extra_artists=(legend, ),
        bbox_inches='tight')


def visualize_experiment(experiment_path):
    result = tune.Analysis(experiment_path)
    visualize(result)
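For completeness, visualize_experiment can also be called directly on a previously produced tune experiment directory, which is what main.py's --mode visualize path does. A minimal sketch, where the import path and the data/<timestamp> directory are placeholders for your own layout and run:

from visualization import visualize_experiment

experiment_path = 'data/2020-01-01T00-00-00'  # placeholder: replace with your run's directory
visualize_experiment(experiment_path)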