"""AlphaZero Evolutionary Strategies."""
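# Overview (summary added for context): an AZNet policy/value network is
# evaluated with Gumbel MuZero search (mctx) inside gymnax environments, and
# its parameters are evolved with OpenAI-ES (evosax) rather than trained with
# the gradient-based AlphaZero procedure.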
import functools
import os
import pickle
import warnings
from argparse import ArgumentParser
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union

import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import mctx
from evosax import ParameterReshaper, FitnessShaper, OpenES

warnings.simplefilter(action='ignore', category=FutureWarning)
os.environ['QT_QPA_PLATFORM'] = 'offscreen'  # headless rendering for the visualizer

PRNGKey = Any
Array = Any
Shape = Tuple[int]
Dtype = Any
parser = ArgumentParser()
parser.add_argument('--aznet_channels', type=int, default=8)
parser.add_argument('--aznet_blocks', type=int, default=5)
parser.add_argument('--aznet_layernorm', type=str, default='None')
# parser.add_argument('--num_hidden_units', type=int, default=16)
# parser.add_argument('--num_linear_layers', type=int, default=1)
parser.add_argument('--num_output_units', type=int, default=3)
parser.add_argument('--output_activation', type=str, default='categorical')
parser.add_argument('--env_name', type=str, default='Breakout-MinAtar')
parser.add_argument('--popsize', type=int, default=1600)
parser.add_argument('--sigma_init', type=float, default=0.03)
parser.add_argument('--sigma_decay', type=float, default=1.0)
parser.add_argument('--sigma_limit', type=float, default=0.01)
parser.add_argument('--lrate_init', type=float, default=0.05)
parser.add_argument('--lrate_decay', type=float, default=1.0)
parser.add_argument('--lrate_limit', type=float, default=0.001)
parser.add_argument('--num_generations', type=int, default=320)
parser.add_argument('--num_mc_evals', type=int, default=1)
parser.add_argument('--network', type=str, default='AZnet') # so it appears as a hyperparameter in wandb
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--num_simulations', type=int, default=16)
parser.add_argument('--visualize_off', action='store_true')
parser.add_argument('--max_epi_len', type=int, default=1000)
args = parser.parse_args()
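# Example invocation (hypothetical script name; values are the defaults above):
#   python alphazero_es.py --env_name Breakout-MinAtar --popsize 1600 --num_simulations 16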
rng = jax.random.PRNGKey(args.seed)
class ResnetV2Block(nn.Module):
    """Pre-activation (ResNet v2) residual block: (norm -> relu -> conv) twice, plus skip."""
    features: int
    kernel_size: Union[int, Iterable[int]] = (3, 3)
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros
    block_name: Optional[str] = None
    dtype: str = 'float32'
    param_dict: Optional[Dict] = None

    @nn.compact
    def __call__(self, x):
        i = x  # identity branch for the residual connection
        if args.aznet_layernorm == 'layernorm':
            x = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        x = nn.relu(x)
        x = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            padding=((1, 1), (1, 1)),
            kernel_init=self.kernel_init if self.param_dict is None else lambda *_: jnp.array(
                self.param_dict['conv1']['weight']),
            use_bias=False,
            dtype=self.dtype
        )(x)
        if args.aznet_layernorm == 'layernorm':
            x = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        x = nn.relu(x)
        x = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            padding=((1, 1), (1, 1)),
            kernel_init=self.kernel_init if self.param_dict is None else lambda *_: jnp.array(
                self.param_dict['conv2']['weight']),
            use_bias=False,
            dtype=self.dtype
        )(x)
        return x + i
class AZNet(nn.Module):
    """AlphaZero-style convolutional trunk with policy and value heads."""
    num_actions: int = args.num_output_units
    num_channels: int = args.aznet_channels
    num_blocks: int = args.aznet_blocks
    dtype: str = 'float32'
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros
    param_dict: Optional[Dict] = None

    @nn.compact
    def __call__(self, x, rng):
        x = x.astype(self.dtype)
        x = nn.Conv(
            features=self.num_channels,
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            kernel_init=self.kernel_init if self.param_dict is None
            else lambda *_: jnp.array(self.param_dict['conv1']['weight']),
            use_bias=False,
            dtype=self.dtype,
        )(x)
        if args.aznet_layernorm == 'layernorm':
            x = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        x = nn.relu(x)
        for _ in range(self.num_blocks):
            x = ResnetV2Block(
                features=self.num_channels,
            )(x)
        if args.aznet_layernorm == 'layernorm':
            x = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        x = nn.relu(x)
        # Policy and value heads; the flat linear layers below are a placeholder for ES testing
        x = x.reshape(-1)
        policy = nn.Dense(features=self.num_actions,
                          kernel_init=self.kernel_init if self.param_dict is None
                          else lambda *_: jnp.array(self.param_dict['fc1']['weight']),
                          bias_init=self.bias_init if self.param_dict is None
                          else lambda *_: jnp.array(self.param_dict['fc1']['bias']),
                          dtype=self.dtype,
                          )(x)
        value = nn.Dense(features=1,
                         kernel_init=self.kernel_init if self.param_dict is None
                         else lambda *_: jnp.array(self.param_dict['fc2']['weight']),
                         bias_init=self.bias_init if self.param_dict is None
                         else lambda *_: jnp.array(self.param_dict['fc2']['bias']),
                         dtype=self.dtype,
                         )(x)
        return policy, value
        # x_out = jax.random.categorical(rng, x)
        # return x_out
# Create placeholder params for env
env, env_params = gymnax.make(args.env_name)
pholder = jnp.zeros(env.observation_space(env_params).shape)
model = AZNet(
    num_actions=args.num_output_units
)
init_policy_params = model.init(
    rng,
    x=pholder,
    rng=rng
)
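# The placeholder forward pass above exists only to materialize a parameter
# PyTree; ParameterReshaper (below) uses it to define the flat ES search space.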
"""Rollout wrapper for gymnax environments."""
import functools
from typing import Any, Optional
import jax
import jax.numpy as jnp
import gymnax
class RolloutWrapper(object):
    """Wrapper to define batch evaluation for generation parameters."""

    def __init__(
        self,
        model_forward=None,
        env_name: str = "Pendulum-v1",
        num_env_steps: Optional[int] = None,
        env_kwargs: Any | None = None,
        env_params: Any | None = None,
    ):
        self.env_name = env_name
        # Define the RL environment & network forward function
        if env_kwargs is None:
            env_kwargs = {}
        if env_params is None:
            env_params = {}
        self.env, self.env_params = gymnax.make(self.env_name, **env_kwargs)
        self.env_params = self.env_params.replace(**env_params)
        self.model_forward = model_forward
        if num_env_steps is None:
            # Default to the environment's own episode limit
            self.num_env_steps = self.env_params.max_steps_in_episode
        else:
            self.num_env_steps = num_env_steps
    @functools.partial(jax.jit, static_argnums=(0, 2,))
    def population_rollout(self, rng_eval, num_simulations, policy_params):
        """Evaluate a whole generation: vmap batch_rollout over the parameter axis."""
        pop_rollout = jax.vmap(self.batch_rollout, in_axes=(None, None, 0))
        return pop_rollout(rng_eval, num_simulations, policy_params)

    @functools.partial(jax.jit, static_argnums=(0, 2,))
    def batch_rollout(self, rng_eval, num_simulations, policy_params):
        """Evaluate one network over several MC rollouts: vmap single_rollout over rng."""
        batch_rollout = jax.vmap(self.single_rollout, in_axes=(0, None, None))
        return batch_rollout(rng_eval, num_simulations, policy_params)

    @functools.partial(jax.jit, static_argnums=(0, 2,))
    def single_rollout(self, rng_input, num_simulations, policy_params):
        """Roll out a single episode, selecting actions with Gumbel MuZero search."""
        # Reset the environment
        rng_reset, rng_episode = jax.random.split(rng_input)
        obs, state = self.env.reset(rng_reset, self.env_params)
        def recurrent_fn(model_params, rng_key: jnp.ndarray, action: jnp.ndarray, state_input):
            """mctx recurrent function: step the simulated env and return prior logits
            and value for the resulting child state."""
            rng_key, rng_step, rng_net = jax.random.split(rng_key, 3)
            # The search embedding is the full rollout carry; strip the leading batch axis
            obs, state, policy_params, rng, cum_reward, valid_mask, step, done = (
                jax.tree.map(lambda s: jnp.squeeze(s, axis=0), state_input))
            # Step the environment
            next_obs, next_state, reward, done, _ = self.env.step(
                rng_step, state, action.squeeze(0), self.env_params
            )
            # Store the transition and reward
            new_cum_reward = cum_reward + reward * valid_mask
            new_valid_mask = valid_mask * (1 - done)
            state_input = [next_obs, next_state, policy_params, rng_key, new_cum_reward, new_valid_mask, step + 1, done]
            state_input = jax.tree.map(lambda s: jnp.expand_dims(s, axis=0), state_input)
            # Compute the logits and value for the next state
            logits, value = self.model_forward(model_params, next_obs, rng_net)
            # Centre the logits for numerical stability (this does not mask actions)
            logits = logits - jnp.max(logits, axis=-1, keepdims=True)
            # logits = jnp.where(state.legal_action_mask, logits, jnp.finfo(logits.dtype).min)
            value = jnp.where(done, 0.0, value)
            # discount = -1.0 * jnp.ones_like(value)  # two-player (alternating-sign) variant
            discount = 0.99 * jnp.ones_like(value)
            discount = jnp.where(done, 0.0, discount)
            recurrent_fn_output = mctx.RecurrentFnOutput(
                reward=jnp.expand_dims(reward, axis=0),
                discount=discount,
                prior_logits=jnp.expand_dims(logits, axis=0),
                value=value,
            )
            return recurrent_fn_output, state_input
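        # mctx convention: RecurrentFnOutput fields carry a leading batch axis,
        # hence the expand_dims on reward and prior_logits above (batch size 1,
        # since the search here runs over a single unbatched environment).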
        def policy_step(state_input, _):
            """lax-compatible step transition: run MCTS at the root, then step the env."""
            obs, state, policy_params, rng, cum_reward, valid_mask, step, done = (
                jax.tree.map(lambda s: jnp.squeeze(s, axis=0), state_input))
            key1, rng_step, net_rng, rng_action, rng_next = jax.random.split(rng, 5)
            logits, value = self.model_forward(policy_params, obs, net_rng)
            logits = jnp.expand_dims(logits, axis=0)
            root = mctx.RootFnOutput(prior_logits=logits, value=value, embedding=state_input)
            policy_output = mctx.gumbel_muzero_policy(
                params=policy_params,
                rng_key=key1,
                root=root,
                recurrent_fn=recurrent_fn,
                num_simulations=num_simulations,
                invalid_actions=jnp.zeros_like(logits),
                qtransform=mctx.qtransform_completed_by_mix_value,
                gumbel_scale=1.0,
            )
            # action = jax.random.categorical(rng_action, jnp.log(policy_output.action_weights), axis=-1)
            action = policy_output.action.squeeze(0)
            # Step the environment with the searched action
            next_obs, next_state, reward, done, _ = self.env.step(
                rng_step, state, action, self.env_params
            )
            new_cum_reward = cum_reward + reward * valid_mask
            new_valid_mask = valid_mask * (1 - done)
            # Carry rng_next (a fresh key) so per-step randomness differs between steps
            carry = [
                next_obs,
                next_state,
                policy_params,
                rng_next,
                new_cum_reward,
                new_valid_mask,
                step + 1,
                done,
            ]
            carry = jax.tree.map(lambda s: jnp.expand_dims(s, axis=0), carry)
            y = [obs, state, action, reward, next_obs, done]
            return carry, y
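        # gumbel_muzero_policy runs Gumbel-top-k root sampling with sequential
        # halving; policy_step uses only the chosen action, though the returned
        # action_weights could also serve as policy targets (see the commented-out
        # categorical sampling above).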
        def early_termination_loop_with_trajectory(policy_step, max_steps, initial_state):
            """Run policy_step until the episode is done or max_steps is reached,
            writing each step's outputs into a preallocated trajectory buffer."""
            def cond_fn(carry):
                state, _ = carry
                _, _, _, _, _, _, step, done = state
                return jnp.logical_and(jnp.logical_not(jnp.all(done)), jnp.any(step < max_steps))

            def body_fn(carry):
                state, trajectory = carry
                next_state, step_output = policy_step(state, ())
                step_counter = next_state[6]
                updated_trajectory = jax.tree_util.tree_map(
                    lambda t, s: t.at[step_counter - 1].set(s),
                    trajectory,
                    step_output
                )
                return (next_state, updated_trajectory)

            # Trace one step to learn the trajectory structure, then preallocate zeros
            trajectory_template = policy_step(initial_state, ())[1]
            initial_trajectory = jax.tree_util.tree_map(
                lambda x: jnp.zeros((max_steps,) + x.shape, dtype=x.dtype),
                trajectory_template
            )
            initial_carry = (initial_state, initial_trajectory)
            final_state, final_trajectory = jax.lax.while_loop(cond_fn, body_fn, initial_carry)
            return final_state, final_trajectory
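        # Design note: jax.lax.while_loop (rather than lax.scan over max_epi_len
        # steps) lets the rollout stop as soon as the episode terminates, at the
        # cost of preallocating the full-length trajectory buffer above.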
        # Initial carry: [obs, state, params, rng, cum_reward, valid_mask, step, done]
        state_input = [
            obs,
            state,
            policy_params,
            rng_episode,
            jnp.array([0.0]),
            jnp.array([1.0]),
            jnp.array(0),
            jnp.array(False),
        ]
        state_input = jax.tree.map(lambda s: jnp.expand_dims(s, axis=0), state_input)
        # Run the early termination loop
        carry_out, scan_out = early_termination_loop_with_trajectory(policy_step, args.max_epi_len, state_input)
        # Return the trajectory plus the episode return and length
        obs, state, action, reward, next_obs, done = scan_out
        cum_return, steps = carry_out[-4], carry_out[-2]
        return obs, state, action, reward, next_obs, done, cum_return, steps
    @property
    def input_shape(self):
        """Get the shape of the observation."""
        rng = jax.random.PRNGKey(0)
        obs, _ = self.env.reset(rng, self.env_params)
        return obs.shape
# Define rollout manager for env
manager = RolloutWrapper(model.apply, env_name=args.env_name, num_env_steps=args.max_epi_len)
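# Shape note: population_rollout vmaps over the parameter population and
# batch_rollout vmaps over rollout keys, so its outputs carry leading axes
# (popsize, num_mc_evals, ...); hence the mean over axis 1 in the loop below.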
# Helper for parameter reshaping into appropriate datastructures
param_reshaper = ParameterReshaper(init_policy_params, n_devices=1)
# Instantiate and initialize the evolution strategy
strategy = OpenES(popsize=args.popsize,
                  num_dims=param_reshaper.total_params,
                  opt_name="sgd")
es_params = strategy.default_params
es_params = es_params.replace(sigma_init=args.sigma_init, sigma_decay=args.sigma_decay, sigma_limit=args.sigma_limit)
es_params = es_params.replace(opt_params=es_params.opt_params.replace(
    lrate_init=args.lrate_init, lrate_decay=args.lrate_decay, lrate_limit=args.lrate_limit))
es_state = strategy.initialize(rng, es_params)
fit_shaper = FitnessShaper(maximize=True, centered_rank=True)
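# FitnessShaper(maximize=True, centered_rank=True) makes the ES update depend
# only on the ordering of returns, not their scale: fitness values are ranked
# and mapped to evenly spaced values in [-0.5, 0.5] (the OpenAI-ES rank transform).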
import wandb
wandb.init(project=f'alphazero-es-{args.env_name}', config=args.__dict__, settings=wandb.Settings(code_dir="."))
print_every_k_gens = 1
best_eval = -jnp.inf
total_frames = 0
for gen in range(args.num_generations):
    rng, rng_init, rng_ask, rng_rollout, rng_eval = jax.random.split(rng, 5)
    # Ask for candidates to evaluate
    x, es_state = strategy.ask(rng_ask, es_state)
    # Reshape parameters into flax FrozenDicts
    reshaped_params = param_reshaper.reshape(x)
    rng_batch_rollout = jax.random.split(rng_rollout, args.num_mc_evals)
    # Perform population evaluation
    _, _, _, _, _, _, cum_ret, steps = manager.population_rollout(rng_batch_rollout, args.num_simulations, reshaped_params)
    # Mean over MC rollouts, shape fitness and update strategy
    fitness = cum_ret.mean(axis=1).squeeze()
    fit_re = fit_shaper.apply(x, fitness)
    es_state = strategy.tell(x, fit_re, es_state)
    # param_new = param_prev - grads * state.lrate
    total_frames += steps.sum().item()
    log = {'score_hist': fitness, 'score': fitness.mean(), 'max_steps': jnp.max(steps), 'steps_hist': steps, 'total_frames': total_frames}
    status = f"Generation: {gen + 1} fitness: {fitness.mean():.3f}, max_steps: {jnp.max(steps)}, total frames: {total_frames}"
    if (gen + 1) % print_every_k_gens == 0:
        # Evaluate the current ES mean at the training search budget and at double it
        eval_params = param_reshaper.reshape(jnp.expand_dims(es_state.mean, axis=0))
        eval_params = jax.tree.map(lambda p: p.squeeze(0), eval_params)
        for num_simulations in [args.num_simulations, args.num_simulations * 2]:
            # 20-episode evaluation rollout for the mean policy
            rng_batch_eval_rollout = jax.random.split(rng_rollout, 20)
            obs, state, action, reward, next_obs, done, cum_ret, steps = manager.batch_rollout(rng_batch_eval_rollout, num_simulations, eval_params)
            eval_score = cum_ret.mean()
            max_steps = jnp.max(steps)
            longest_run = jnp.argmax(steps)
            log.update({
                f"eval_score_{num_simulations}": cum_ret.mean(),
                f"eval_steps_{num_simulations}": jnp.max(steps),
            })
        # Save the model if it improved (scored on the last, doubled-budget evaluation)
        if eval_score.item() > best_eval:
            best_eval = max(best_eval, eval_score.item())
            with open(f'{wandb.run.dir}/best_{best_eval:.2f}.param', 'wb') as f:
                pickle.dump(eval_params, f)
            # Visualize the longest evaluation episode
            if not args.visualize_off:
                print('generating visualization')
                states = []
                for i in range(max_steps.item()):
                    states.append(jax.tree.map(lambda x: x[longest_run, i], state))
                from gymnax.visualize import Visualizer
                vis = Visualizer(env, env_params, states)
                anim_path = f"{wandb.run.dir}/anim.gif"
                vis.animate(anim_path)
                log.update({"viz": wandb.Video(anim_path, fps=8, format='gif')})
        log.update({"best_eval": best_eval})
        status += f' Eval fitness: {cum_ret.mean():.3f}, max_steps: {max_steps}'
    wandb.log(log, step=gen)
    print(status)