DeepCoord agent with SAC
# DeepCoord uses DRL for network and service coordination.
# By default, it uses DDPG from the keras-rl library:
# https://github.com/RealVNF/DeepCoord
# This module shows how to implement the DRL agent with SAC from stable-baselines instead.
# It is copied from https://github.com/RealVNF/rl-coordination/blob/master/src/rlsp/agents/rlsp_sac.py (private repo).
import csv
import logging
import os

import numpy as np
from gym.spaces import Box
from stable_baselines.sac import SAC
from stable_baselines.sac.policies import MlpPolicy

from rlsp.agents.rlsp_agent import RLSPAgent
from rlsp.utils.util_functions import create_simulator

logger = logging.getLogger(__name__)
class SAC_Agent(RLSPAgent):
    """SAC agent for RLSP."""

    def __init__(self, agent_helper, logger):
        self.agent_helper = agent_helper
        self.callbacks_prepared = False
        # Flatten the multi-dimensional scheduling action space into a 1D Box
        # in [-1, 1], which is what the SAC implementation expects.
        shape_flattened = (np.prod(self.agent_helper.env.env_limits.scheduling_shape),)
        self.agent_helper.env.action_space = Box(low=-1, high=1, shape=shape_flattened)
        # All SAC hyperparameters are read from the agent config.
        self.agent = SAC(MlpPolicy, agent_helper.env,
                         gamma=self.agent_helper.config['gamma'],
                         learning_rate=self.agent_helper.config['learning_rate'],
                         buffer_size=self.agent_helper.config['buffer_size'],
                         learning_starts=self.agent_helper.config['learning_starts'],
                         train_freq=self.agent_helper.config['train_freq'],
                         batch_size=self.agent_helper.config['batch_size'],
                         tau=self.agent_helper.config['tau'],
                         ent_coef=self.agent_helper.config['ent_coef'],
                         target_update_interval=self.agent_helper.config['target_update_interval'],
                         gradient_steps=self.agent_helper.config['gradient_steps'],
                         target_entropy=self.agent_helper.config['target_entropy'],
                         random_exploration=self.agent_helper.config['random_exploration'],
                         policy_kwargs={'layers': self.agent_helper.config['hidden_layers']},
                         tensorboard_log='./')
    def fit(self, env, episodes, verbose, episode_steps, callbacks, log_interval):
        """Mask the agent fit function."""
        steps = episodes * self.agent_helper.episode_steps
        self.agent.learn(steps, callback=self._callbacks, tb_log_name=self.agent_helper.graph_path)
        self.close_callbacks()
    def test(self, env, episodes, verbose, episode_steps, callbacks):
        """Mask the agent test function."""
        # Only recreate the simulator and callbacks if the test follows training.
        # Otherwise, the CSV headers would be written twice when the agent is
        # called for testing only.
        if self.agent_helper.train:
            # Create a fresh simulator with the test argument
            self.agent_helper.env.simulator = create_simulator(self.agent_helper)
            self.callbacks_prepared = False
        obs = self.agent_helper.env.reset()
        locals_ = {'episode_rewards': [0]}
        for i in range(episodes * self.agent_helper.episode_steps):
            action, _states = self.agent.predict(obs)
            obs, reward, done, info = self.agent_helper.env.step(action)
            # Rough re-implementation of the training callbacks
            locals_['step'] = i
            locals_['reward'] = reward
            locals_['episode_rewards'][-1] += reward
            self._callbacks(locals_, {})  # Call callbacks before adding a new episode reward
            if done:
                logger.info(f"Finished testing step {i+1}. Episode reward = {locals_['episode_rewards'][-1]}")
                # Start a fresh tally for the next episode; the step's reward is
                # already counted in the episode that just finished.
                locals_['episode_rewards'].append(0)
        self.close_callbacks()
    def save_weights(self, weights_file, overwrite=True):
        logger.info("saving model and weights to %s", weights_file)
        dir_path = os.path.dirname(os.path.realpath(weights_file))
        os.makedirs(dir_path, exist_ok=True)
        # weights_file is treated as a prefix; the full model is saved to '<prefix>weights'.
        self.agent.save(f'{weights_file}weights')

    def load_weights(self, weights_file):
        self.agent = SAC.load(weights_file)
    def _callbacks(self, locals_, globals_):
        """Write the step and episode rewards to CSV files."""
        if not self.callbacks_prepared:
            self.prepare_callbacks()
        self.run_reward_csv_writer.writerow([locals_['step'], locals_.get('reward')])
        # At the end of each episode, also record the total episode reward.
        if (locals_['step'] != 1) and ((locals_['step'] + 1) % self.agent_helper.episode_steps == 0):
            self.episode_reward_csv_writer.writerow([len(locals_['episode_rewards']), locals_['episode_rewards'][-1]])
    def prepare_callbacks(self):
        self.run_reward_file = open(f"{self.agent_helper.config_dir}run_reward.csv", 'a+', newline='')
        self.run_reward_csv_writer = csv.writer(self.run_reward_file)
        self.episode_rewards_file = open(f"{self.agent_helper.config_dir}episode_reward.csv", 'a+', newline='')
        self.episode_reward_csv_writer = csv.writer(self.episode_rewards_file)
        self.run_reward_csv_writer.writerow(['run', 'reward'])  # add a header
        self.episode_reward_csv_writer.writerow(['episode', 'reward'])  # add a header
        self.callbacks_prepared = True

    def close_callbacks(self):
        self.episode_rewards_file.close()
        self.run_reward_file.close()
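
For reference, a minimal usage sketch follows. It is not part of the original file: the `agent_helper` object (bundling the env, config, and paths read in `__init__` above) comes from the RLSP framework and is only stubbed here, and the config values shown are simply the stable-baselines v2 SAC defaults rather than tuned values from DeepCoord.

# Hypothetical usage sketch. Assumptions: agent_helper is provided by the RLSP
# framework; the config values are the stable-baselines v2 SAC defaults.
config = {
    'gamma': 0.99,
    'learning_rate': 3e-4,
    'buffer_size': 50000,
    'learning_starts': 100,
    'train_freq': 1,
    'batch_size': 64,
    'tau': 0.005,
    'ent_coef': 'auto',
    'target_update_interval': 1,
    'gradient_steps': 1,
    'target_entropy': 'auto',
    'random_exploration': 0.0,
    'hidden_layers': [64, 64],  # default MlpPolicy network size
}
agent_helper = ...  # supplied by RLSP; must expose env, config, episode_steps, paths
agent = SAC_Agent(agent_helper, logger)
# Train for 100 episodes, then save the model under a path prefix.
agent.fit(agent_helper.env, episodes=100, verbose=1,
          episode_steps=agent_helper.episode_steps, callbacks=None, log_interval=10)
agent.save_weights('results/sac_')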