# Source: GitHub gist cloneofsimo/46cee71557b33f0a5b3c97cd99f73245
# (created August 1, 2023 07:56).
# NOTE: GitHub flagged this file as containing bidirectional Unicode text,
# which may be interpreted or compiled differently than it appears. Review it
# in an editor that reveals hidden Unicode characters.
"""Train PPO on the rg2 ``Rg2UEnv`` walker environment.

Runs ``N_ENVS`` simulator instances in subprocesses, normalizes observations
and rewards with ``VecNormalize``, writes periodic checkpoints (policy plus
normalization statistics) to ``./logs/``, and saves the final model and
normalization statistics when training finishes.
"""
from rg2.gym import Rg2UEnv, WalkerEnvConfig
from gym.wrappers import TimeLimit
from stable_baselines3 import PPO, SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback

# Base part of the initial generalized coordinates; the joint positions are
# appended after it in gc_init.
# NOTE(review): looks like [x, y, z, qw, qx, qy, qz] with an identity
# quaternion -- confirm against rg2's coordinate convention.
NOM_POS = [0, 0, 0.01, 1.0, 0.0, 0.0, 0.0]

# Presumably the number of actuated joints: it sizes the joint part of
# gc_init, gv_init (on top of 6 base velocity entries) and the action
# mean/std vectors. TODO confirm against the URDF.
N_JOINTS = 15

# Number of parallel simulator worker processes.
N_ENVS = 8


def make_env(rank):
    """Return a zero-argument factory that builds the environment for *rank*.

    *rank* is passed to the simulator as its seed, and only rank 0 is created
    with visualization enabled. Episodes are truncated at 400 steps by
    ``TimeLimit`` and wrapped in SB3's ``Monitor`` to record episode stats.
    """

    def _init():
        config = WalkerEnvConfig(
            resource_dir="/home/user/Downloads/wcollision/urdf/robot.urdf",
            gc_init=NOM_POS + [0.0] * N_JOINTS,
            gv_init=[0.0] * (N_JOINTS + 6),
            action_mean=[0.0] * N_JOINTS,
            action_std=[0.3] * N_JOINTS,
            p_gain=50,
            d_gain=0.2,
            env_params=[-1, -1, -1, -1, -1, -1],
        )
        env = Rg2UEnv(
            config.get_cpp_object(),
            seed=rank,
            visualizable=(rank == 0),
        )
        env = TimeLimit(env, 400)
        return Monitor(env)

    return _init


def main():
    """Build the vectorized environments and run PPO training."""
    envs = SubprocVecEnv([make_env(i) for i in range(N_ENVS)])
    envs = VecNormalize(envs, norm_obs=True, norm_reward=True, clip_obs=10.0)
    # envs.env_method("turn_on_visualization", indices=0)
    try:
        checkpoint_callback = CheckpointCallback(
            # save_freq counts callback calls, i.e. vectorized steps: with
            # N_ENVS workers a checkpoint is written every 10_000 * N_ENVS
            # environment steps. Use max(10_000 // N_ENVS, 1) if the intent
            # was "every 10k total timesteps".
            save_freq=10_000,
            save_path="./logs/",
            name_prefix="rl_model",
            # PPO is on-policy and keeps no replay buffer, so this flag
            # should have no effect for this model.
            save_replay_buffer=True,
            # Needed to reload the policy later: it was trained on
            # normalized observations.
            save_vecnormalize=True,
        )
        model = PPO("MlpPolicy", envs, verbose=1, batch_size=64, learning_rate=2e-4)
        model.learn(total_timesteps=100_000_000, callback=checkpoint_callback)
        # Checkpoints only fire every save_freq calls, so persist the final
        # policy and normalization statistics explicitly.
        model.save("./logs/rl_model_final")
        envs.save("./logs/rl_model_vecnormalize_final.pkl")
    finally:
        # Shut down the SubprocVecEnv worker processes even if training fails
        # or is interrupted.
        envs.close()


if __name__ == "__main__":
    main()
# Discussion of this gist takes place on GitHub; sign in there to comment.