outlaw rogue DQN toy problem
outlaw_agent.py
import numpy as np

from outlaw_environment import OutlawEnvironment
from stable_baselines3 import PPO
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy

# Sanity-check the custom environment against the Gym API before training.
env = OutlawEnvironment()
check_env(env, warn=True)

model = DQN("MlpPolicy",
            env,
            learning_rate=0.02,
            verbose=1,
            tensorboard_log='./tensorboard_logdir/',
            )
model.learn(
    total_timesteps=int(1e5),
    progress_bar=True,
)

# Greedy rollout with the trained policy: print the chosen action and reward at each step.
vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(20):
    vec_env.render("console")
    action, _states = model.predict(obs, deterministic=True)
    print("best action: ", action[0])
    obs, rewards, dones, info = vec_env.step(action)
    print("reward: ", rewards[0])

# poetry run python3 outlaw_agent.py
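The script imports evaluate_policy but never calls it. Below is a minimal evaluation sketch that is not part of the original gist: because the environment never ends an episode, it wraps a fresh environment in gymnasium's TimeLimit so evaluate_policy can report a mean episode reward. The 20-step horizon and 100 evaluation episodes are arbitrary choices for illustration, and it assumes an SB3 2.x / gymnasium setup.

# Hypothetical evaluation sketch (not in the original gist).
import gymnasium as gym
from stable_baselines3.common.monitor import Monitor

# Wrap a fresh env so episodes actually end; 20 steps is an arbitrary horizon.
eval_env = Monitor(gym.wrappers.TimeLimit(OutlawEnvironment(), max_episode_steps=20))
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=100, deterministic=True)
print(f"mean 20-step episode reward: {mean_reward:.0f} +/- {std_reward:.0f}")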
outlaw_environment.py
import gymnasium as gym
import numpy as np

# Global 0/1 toggles: BROADSIDES doubles combo-point generation;
# SKULL raises the Sinister Strike double-hit chance.
BROADSIDES = 0
SKULL = 0

# Actions are just (impure) functions that return the same tuple type as the Environment step() fn.
# Allowed to modify the agent_state object!
def PistolShot(agent_state):
    BASE_DAM = 8708
    fth = agent_state.fth_stacks
    current_cp = agent_state.combo_points
    cp_generated = 1
    cp_generated += (3 if fth > 0 else 0)
    cp_generated *= (2 if BROADSIDES else 1)
    agent_state.fth_stacks = max(0, fth - 1)
    agent_state.combo_points = min(7, current_cp + cp_generated)
    reward = 3 * BASE_DAM if (fth > 0) else BASE_DAM
    terminated = False
    truncated = False
    observation = agent_state.get_observation()
    info = agent_state.get_info()
    return (observation, reward, terminated, truncated, info)

def SinisterStrike(agent_state):
    BASE_DAM = 13896
    fth = agent_state.fth_stacks
    current_cp = agent_state.combo_points
    cp_generated = 2 if BROADSIDES else 1
    prob_double = (0.45 + 0.25 * (1 if SKULL else 0))
    if np.random.random() < prob_double:
        # Double-hit proc: grants an fth stack and doubles the combo points generated.
        agent_state.fth_stacks = min(2, fth + 1)
        agent_state.combo_points = min(7, current_cp + 2 * cp_generated)
        reward = 2 * BASE_DAM
    else:
        agent_state.combo_points = min(7, current_cp + 1 * cp_generated)
        reward = BASE_DAM
    terminated = False
    truncated = False
    observation = agent_state.get_observation()
    info = agent_state.get_info()
    return (observation, reward, terminated, truncated, info)

def Dispatch(agent_state):
    BASE_DAM = 7842
    current_cp = agent_state.combo_points
    new_cp = 0 if current_cp <= 5 else 1
    prob_extra_cp = (0.2 * current_cp) if current_cp <= 5 else (0.2 * (current_cp - 5))
    reward = current_cp * BASE_DAM
    # Ruthlessness proc or not.
    agent_state.combo_points = new_cp + (1 if np.random.random() <= prob_extra_cp else 0)
    terminated = False
    truncated = False
    observation = agent_state.get_observation()
    info = agent_state.get_info()
    return (observation, reward, terminated, truncated, info)

# A helper type to keep track of our agent's state.
class AgentState():
    def __init__(self):
        self.fth_stacks = 0  # normalized to 0, 1, 2 instead of 0, 3, 6
        self.combo_points = 0

    def __repr__(self):
        return f"{self.fth_stacks}, {self.combo_points}"

    def describe_observation_space(self):
        return gym.spaces.Box(low=np.array([0, 0]), high=np.array([2, 7]), shape=(2,), dtype=np.uint8)

    def get_observation(self):
        return np.array([self.fth_stacks, self.combo_points], dtype=np.uint8)

    def get_info(self):
        return {}

class OutlawEnvironment(gym.Env):
    metadata = {"render_modes": ["console"], "render_fps": 30}

    # A list of functions that take an agent_state, (potentially) mutate it, and return
    # a tuple of the same form that step() is looking for.
    ACTIONS = [
        PistolShot,
        SinisterStrike,
        Dispatch,
    ]

    def __init__(self, render_mode="console"):
        super().__init__()
        self.render_mode = render_mode
        self.agent_state = AgentState()
        self.action_space = gym.spaces.Discrete(len(OutlawEnvironment.ACTIONS))
        self.observation_space = self.agent_state.describe_observation_space()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        self.sim_time = 0
        self.agent_state = AgentState()
        return (self.agent_state.get_observation(), self.agent_state.get_info())

    def step(self, action):
        # Note: no action ever sets terminated/truncated, so this is a continuing task.
        self.sim_time += 1
        return OutlawEnvironment.ACTIONS[action](self.agent_state)

    def render(self):
        if self.render_mode == "console":
            print(self.agent_state)

    def close(self):
        pass
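For reference, here is a small sanity-check harness that is not part of the original gist but could be appended to outlaw_environment.py: it rolls the environment forward with random actions and prints the observation/reward stream, which is handy for eyeballing the combo-point and fth_stacks dynamics before handing the environment to DQN. The 10-step horizon is arbitrary.

# Hypothetical sanity check, not in the original gist: random-action rollout.
if __name__ == "__main__":
    env = OutlawEnvironment()
    obs, info = env.reset()
    print("initial obs:", obs)
    for _ in range(10):
        action = int(env.action_space.sample())
        obs, reward, terminated, truncated, info = env.step(action)
        print(f"{OutlawEnvironment.ACTIONS[action].__name__:>14s} -> obs={obs}, reward={reward}")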