import numpy as np

from .lib import pyflare as fl
from .fish_env_basic import FishEnvBasic
from ray.rllib.env.apis.task_settable_env import TaskSettableEnv


class KoiCruisingEnv(FishEnvBasic, TaskSettableEnv):
    """Koi cruising environment with a task-settable goal direction (theta, phi)."""

    def __init__(self,
                 control_dt=0.2,
                 wp=np.array([0.0, 1.0]),
                 wr=0.0,
                 wa=0.5,
                 max_time=10,
                 done_dist=0.15,
                 radius=2,
                 theta=np.array([90, 90]),
                 phi=np.array([0, 0]),
                 data_folder="",
                 env_json: str = '../assets/env_file/env_cruising_koi.json',
                 gpuId: int = 0,
                 couple_mode: fl.COUPLE_MODE = fl.COUPLE_MODE.TWO_WAY,
                 empirical_force_amplifier=1600,
                 is3D=False,
                 use_com=True) -> None:
        super().__init__(control_dt, wp, wr, wa, max_time, done_dist, radius, theta,
                         phi, data_folder, env_json, gpuId, couple_mode,
                         empirical_force_amplifier, is3D, use_com)

    def sample_tasks(self, n_tasks):
        # Each task is a goal direction: theta is drawn from the pair {0, 180}
        # and phi from the pair {0, 360} (degrees).
        sample_theta = np.random.choice((0, 180), (n_tasks,))
        sample_phi = np.random.choice((0, 360), (n_tasks,))
        return np.array([{"theta": theta, "phi": phi, "dist": 0.01}
                         for theta, phi in zip(sample_theta, sample_phi)])

    def set_task(self, task):
        # Forward the goal direction (plus a fixed goal distance) to the base env.
        task_param = {"theta": task["theta"], "phi": task["phi"], "dist": 0.01}
        super().set_task(task_param)

    def get_task(self):
        return {"theta": self.theta, "phi": self.phi}
The training script below, adapted from RLlib's custom_env.py example, registers KoiCruisingEnv with Tune and trains it with PPO in a manual loop:
""" | |
Example of a custom gym environment and model. Run this for a demo. | |
This example shows: | |
- using a custom environment | |
- using a custom model | |
- using Tune for grid search to try different learning rates | |
You can visualize experiment results in ~/ray_results using TensorBoard. | |
Run example with defaults: | |
$ python custom_env.py | |
For CLI options: | |
$ python custom_env.py --help | |
""" | |
import argparse | |
import gym | |
from gym.spaces import Discrete, Box | |
import numpy as np | |
import os | |
import random | |
import ray | |
from ray import tune | |
from ray.rllib.agents import ppo | |
from ray.rllib.env.env_context import EnvContext | |
from ray.rllib.models import ModelCatalog | |
from ray.rllib.models.tf.tf_modelv2 import TFModelV2 | |
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork | |
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 | |
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC | |
from ray.rllib.utils.framework import try_import_tf, try_import_torch | |
from ray.rllib.utils.test_utils import check_learning_achieved | |
from ray.tune.logger import pretty_print | |
from ray.rllib.examples.env.halfcheetah_rand_direc import HalfCheetahRandDirecEnv | |
tf1, tf, tfv = try_import_tf() | |
torch, nn = try_import_torch() | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--run", | |
type=str, | |
default="PPO", | |
help="The RLlib-registered algorithm to use.") | |
parser.add_argument( | |
"--framework", | |
choices=["tf", "tf2", "tfe", "torch"], | |
default="tf", | |
help="The DL framework specifier.") | |
parser.add_argument( | |
"--as-test", | |
action="store_true", | |
help="Whether this script should be run as a test: --stop-reward must " | |
"be achieved within --stop-timesteps AND --stop-iters.") | |
parser.add_argument( | |
"--stop-iters", | |
type=int, | |
default=50, | |
help="Number of iterations to train.") | |
parser.add_argument( | |
"--stop-timesteps", | |
type=int, | |
default=100000, | |
help="Number of timesteps to train.") | |
parser.add_argument( | |
"--stop-reward", | |
type=float, | |
default=0.1, | |
help="Reward at which we stop training.") | |
parser.add_argument( | |
"--no-tune", | |
action="store_true", | |
help="Run without Tune using a manual train loop instead. In this case," | |
"use PPO without grid search and no TensorBoard.") | |
parser.add_argument( | |
"--local-mode", | |
action="store_true", | |
help="Init Ray in local mode for easier debugging.") | |
if __name__ == "__main__": | |
args = parser.parse_args() | |
print(f"Running with following CLI options: {args}") | |
ray.init(local_mode=args.local_mode) | |
def env_creator(env_name): | |
if env_name == 'CustomEnv-v0': | |
from gym_fish.envs.fish_env_cruising import KoiCruisingEnv as env | |
else: | |
raise NotImplementedError | |
return env | |
env = env_creator('CustomEnv-v0') | |
tune.register_env("myEnv", lambda cfg: env()) | |
config = { | |
"env": 'myEnv', | |
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "1")), | |
"model": { | |
"custom_model": "my_model", | |
"vf_share_layers": True, | |
}, | |
"num_workers": 1, # parallelism | |
"framework": args.framework, | |
"log_level": "DEBUG" | |
} | |
stop = { | |
"training_iteration": args.stop_iters, | |
"timesteps_total": args.stop_timesteps, | |
"episode_reward_mean": args.stop_reward, | |
} | |
if args.run != "PPO": | |
raise ValueError("Only support --run PPO with --no-tune.") | |
print("Running manual train loop without Ray Tune.") | |
ppo_config = ppo.DEFAULT_CONFIG.copy() | |
ppo_config.update(config) | |
# use fixed learning rate instead of grid search (needs tune) | |
ppo_config["lr"] = 1e-3 | |
trainer = ppo.PPOTrainer(config=ppo_config, env="myEnv") | |
# run manual training loop and print results after each iteration | |
for _ in range(args.stop_iters): | |
result = trainer.train() | |
print(pretty_print(result)) | |
ray.shutdown() |
For reference, this is the TaskSettableEnv interface from ray.rllib.env.apis.task_settable_env that KoiCruisingEnv implements:
import gym
from typing import List, Any

TaskType = Any  # Can be different types depending on env, e.g., int or dict


class TaskSettableEnv(gym.Env):
    """
    Extension of gym.Env to define a task-settable Env.

    Your env must implement this interface in order to be used with MAML.
    For curriculum learning, you can add this API to your env such that
    the `env_task_fn` can set the next task as needed.

    Supports:
    - Sampling from a distribution of tasks for meta-learning.
    - Setting the env to any task it supports.
    - Getting the current task this env has been set to.

    Examples:
        >>> env = TaskSettableEnv(...)
        >>> ...
        >>> Trainer.workers.foreach_env(lambda base_env: base_env.my_prop)
    """

    def sample_tasks(self, n_tasks: int) -> List[TaskType]:
        """Samples tasks for the meta-environment.

        Args:
            n_tasks (int): Number of different meta-tasks needed.

        Returns:
            tasks (list): An (n_tasks) length list of tasks.
        """
        raise NotImplementedError

    def set_task(self, task: TaskType) -> None:
        """Sets the specified task on the current environment.

        Args:
            task: Task of the meta-learning environment.
        """
        raise NotImplementedError

    def get_task(self) -> TaskType:
        """Gets the task the agent is performing in the current environment.

        Returns:
            task: Task of the meta-learning environment.
        """
        raise NotImplementedError
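To make the interface concrete, here is a hedged sketch of how a new task could be pushed to every environment held by an RLlib trainer between training iterations, following the Trainer.workers.foreach_env pattern quoted in the docstring above. The trainer variable, the iteration count, and the inline task sampling are assumptions for illustration; it presumes the trainer's rollout workers wrap a TaskSettableEnv such as KoiCruisingEnv registered as "myEnv" in the script above.

# Sketch only: curriculum/meta-learning style task switching between training
# iterations. Assumes `trainer` is the PPOTrainer built in the script above and
# that every rollout worker holds a TaskSettableEnv (here, KoiCruisingEnv).
import numpy as np

for i in range(5):
    result = trainer.train()
    # Draw a new goal direction the same way KoiCruisingEnv.sample_tasks does.
    task = {"theta": int(np.random.choice((0, 180))),
            "phi": int(np.random.choice((0, 360))),
            "dist": 0.01}
    # Push the new task to every env on the local and remote rollout workers,
    # mirroring the foreach_env usage shown in the class docstring.
    trainer.workers.foreach_env(lambda env: env.set_task(task))
    print("iteration", i, "-> task now set to:", task)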