from stable_baselines.common.policies import FeedForwardPolicy
from stable_baselines import A2C
# Custom MLP policy of three layers of size 128 each
class CustomPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs,
                                           layers=[128, 128, 128],
                                           feature_extraction="mlp")
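To make `layers=[128, 128, 128]` concrete, here is a minimal NumPy sketch of what a feed-forward latent network of that shape computes: three hidden layers of 128 units applied in sequence. The function name `mlp_forward`, the tanh activation, and the random weight initialization are assumptions for illustration; this is not Stable Baselines' implementation, and a real policy would learn its weights.

```python
import numpy as np


def mlp_forward(obs, layer_sizes=(128, 128, 128), seed=0):
    """Hypothetical sketch: pass a batch of observations through
    len(layer_sizes) fully connected layers with tanh activations.

    Weights are random here (assumption); a trained policy learns them.
    """
    rng = np.random.default_rng(seed)
    h = np.asarray(obs, dtype=np.float64)
    for size in layer_sizes:
        # Scale the weights by 1/sqrt(fan_in) so activations stay in a sane range.
        w = rng.normal(scale=1.0 / np.sqrt(h.shape[-1]), size=(h.shape[-1], size))
        b = np.zeros(size)
        h = np.tanh(h @ w + b)
    return h


batch = np.ones((4, 8))  # 4 observations with 8 features each (made-up shape)
latent = mlp_forward(batch)
print(latent.shape)  # (4, 128)
```

The last layer's width (128) becomes the size of the latent features that the policy and value heads would read from.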
from stable_baselines import PPO2
# Define and train a model in one line of code!
trained_model = PPO2('MlpPolicy', 'CartPole-v1').learn(total_timesteps=10000)
# You can then access the gym env using trained_model.get_env()
import os
import gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines.bench import Monitor
from stable_baselines.results_plotter import load_results, ts2xy
from stable_baselines import DDPG
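The imports of `Monitor`, `load_results` and `ts2xy` suggest this script reads per-episode Monitor logs and turns them into (timesteps, reward) pairs, for example to track the mean reward of recent episodes during training. Below is a minimal NumPy sketch of that conversion, assuming the per-episode rewards and lengths are already available as lists. The helper names `timesteps_and_rewards` and `mean_recent_reward` are hypothetical and are not the library's API.

```python
import numpy as np


def timesteps_and_rewards(episode_rewards, episode_lengths):
    """Hypothetical helper: x = cumulative timesteps at the end of each
    episode, y = that episode's total reward."""
    x = np.cumsum(episode_lengths)
    y = np.asarray(episode_rewards, dtype=np.float64)
    return x, y


def mean_recent_reward(y, window=100):
    """Mean reward over the last `window` episodes (all of them if fewer)."""
    return float(np.mean(y[-window:]))


rewards = [10.0, 20.0, 30.0]  # made-up values for illustration
lengths = [100, 150, 200]
x, y = timesteps_and_rewards(rewards, lengths)
print(x.tolist())                        # [100, 250, 450]
print(mean_recent_reward(y, window=2))   # 25.0
```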
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines import ACER
# There already exists an environment generator
# that will make and wrap Atari environments correctly.
# Here we are also multiprocessing training (num_env=4 => 4 processes)
env = make_atari_env('PongNoFrameskip-v4', num_env=4, seed=0)
# Frame-stacking with 4 frames
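Frame stacking gives the agent a short history (for example, the direction the Pong ball is moving) by concatenating the most recent frames along the channel axis. The sketch below shows the idea with a fixed-length buffer. The class name `FrameStacker`, the 84x84x1 grayscale frame shape, and zero-filling the history on reset are assumptions for illustration; `VecFrameStack` is the library's wrapper and may differ in details.

```python
from collections import deque

import numpy as np


class FrameStacker:
    """Hypothetical sketch: keep the last n_stack frames and concatenate
    them along the last (channel) axis."""

    def __init__(self, n_stack=4):
        self.n_stack = n_stack
        self.frames = deque(maxlen=n_stack)

    def reset(self, first_frame):
        # Pad the history with zero frames so the output shape is fixed (assumption).
        for _ in range(self.n_stack - 1):
            self.frames.append(np.zeros_like(first_frame))
        self.frames.append(first_frame)
        return self.observation()

    def step(self, frame):
        self.frames.append(frame)  # deque(maxlen=...) drops the oldest frame
        return self.observation()

    def observation(self):
        return np.concatenate(list(self.frames), axis=-1)


stacker = FrameStacker(n_stack=4)
obs = stacker.reset(np.full((84, 84, 1), 5, dtype=np.uint8))
print(obs.shape)           # (84, 84, 4)
print(obs[0, 0].tolist())  # [0, 0, 0, 5]
obs = stacker.step(np.ones((84, 84, 1), dtype=np.uint8))
print(obs[0, 0].tolist())  # [0, 0, 5, 1]
```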
import imageio
import numpy as np
from stable_baselines.common.policies import MlpPolicy
from stable_baselines import A2C
model = A2C(MlpPolicy, "LunarLander-v2").learn(100000)
images = []
obs = model.env.reset()
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnPolicy
from stable_baselines import PPO2
# There already exists an environment generator
# that will make and wrap Atari environments correctly
env = make_atari_env('DemonAttackNoFrameskip-v4', num_env=8, seed=0)
model = PPO2(CnnPolicy, env, verbose=1)
model.learn(total_timesteps=10000)
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines import PPO2
env = DummyVecEnv([lambda: gym.make("Reacher-v2")])
# Automatically normalize the input features
env = VecNormalize(env, norm_obs=True, norm_reward=False,
                   clip_obs=10.)
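Normalizing observations means keeping running estimates of each feature's mean and variance, standardizing every new observation with them, and then clipping the result (here to `[-10, 10]`, matching `clip_obs=10.`). The sketch below shows that idea with the standard parallel mean/variance update. The names `RunningMeanStd` and `normalize_obs` and the epsilon values are assumptions for this sketch; it is written in the spirit of `VecNormalize` with `norm_obs=True`, not copied from the library.

```python
import numpy as np


class RunningMeanStd:
    """Hypothetical sketch: running per-feature mean/variance, updated batch by batch."""

    def __init__(self, shape=(), epsilon=1e-4):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon  # small pseudo-count avoids division by zero

    def update(self, batch):
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        # Combine the sums of squared deviations of the old and new data.
        m2 = (self.var * self.count + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total


def normalize_obs(obs, rms, clip_obs=10.0, epsilon=1e-8):
    # Standardize with the running statistics, then clip to [-clip_obs, clip_obs].
    return np.clip((obs - rms.mean) / np.sqrt(rms.var + epsilon), -clip_obs, clip_obs)


rms = RunningMeanStd(shape=(2,))
data = np.array([[1.0, 100.0], [3.0, 300.0], [5.0, 500.0]])  # made-up observations
rms.update(data)
print(normalize_obs(data, rms).round(2))
```

Because both features end up on a comparable scale, the second feature's much larger raw magnitude no longer dominates the network's inputs.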
tensorboard --logdir /tmp/a2c_cartpole_tensorboard/
import pytest
import numpy as np
from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, TRPO
from stable_baselines.common import set_global_seeds
MODEL_LIST_DISCRETE = [
    A2C,
    ACER,
    ACKTR,
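The test module imports `set_global_seeds`, presumably to check that runs with the same seed are reproducible. The underlying idea can be shown without any RL library: reseeding the global random number generators makes subsequent draws identical. The helpers `seed_everything` and `sample_run` are hypothetical; they seed only Python's `random` and NumPy, while the library's `set_global_seeds` may seed additional generators (such as TensorFlow's).

```python
import random

import numpy as np


def seed_everything(seed):
    """Hypothetical helper: seed Python's and NumPy's global RNGs."""
    random.seed(seed)
    np.random.seed(seed)


def sample_run(seed, n=5):
    # Stand-in for "train a model": any code that consumes global randomness.
    seed_everything(seed)
    return [random.random() for _ in range(n)], np.random.rand(n).tolist()


# Same seed -> identical results; a different seed -> (almost surely) different ones.
print(sample_run(0) == sample_run(0))  # True
```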
import gym
import numpy as np
import cma
from collections import OrderedDict
from stable_baselines import A2C
def flatten(params):
    """