Skip to content

Instantly share code, notes, and snippets.

View zh4ngx's full-sized avatar

Andy Zhang zh4ngx

  • San Francisco, CA
View GitHub Profile
@zh4ngx
zh4ngx / cart_pole_cem_3.py
Created July 4, 2017 01:36
CartPole-v0 Cross Entropy Method with Minimal Params
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.spaces import Discrete, Box
from gym.wrappers.monitoring import Monitor
# ================================================================
# Policies
# ================================================================
@zh4ngx
zh4ngx / cross_entropy.py
Created July 4, 2017 04:31
Cross Entropy (Evolutionary Strategy) on CartPole-v0 - somewhat overparameterized
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor
from evaluation import noisy_evaluation, do_episode
from utils import get_dim_theta, make_policy
# Task settings:
# Environment the cross-entropy search runs against. The id is hard-coded;
# edit it to target a different Gym task.
env = gym.make('CartPole-v0') # Change as needed
@zh4ngx
zh4ngx / evaluation.py
Created July 4, 2017 04:58
Monte Carlo EM - weighted sampling of mean/variance of theta by reward
from utils import make_policy
def do_episode(policy, env, max_steps, render=False):
total_rew = 0
ob = env.reset()
for t in range(max_steps):
a = policy.act(ob)
(ob, reward, done, _info) = env.step(a)
total_rew += reward
@zh4ngx
zh4ngx / evaluation.py
Created July 4, 2017 05:21
Monte Carlo EM CartPole-v0 with exponentially weighted variance
from utils import make_policy
def do_episode(policy, env, max_steps, render=False):
total_rew = 0
ob = env.reset()
for t in range(max_steps):
a = policy.act(ob)
(ob, reward, done, _info) = env.step(a)
total_rew += reward
@zh4ngx
zh4ngx / cross_entropy.py
Created July 4, 2017 19:17
Cleaned up CartPole
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor
from policy import Policy
# Task settings:
# Environment the cross-entropy search runs against. The id is hard-coded;
# edit it to target a different Gym task.
env = gym.make('CartPole-v0') # Change as needed
# Wrap the env in gym's Monitor so episodes are recorded under
# 'tmp/cart-pole-cross-entropy-1'. force=True is meant to clear earlier monitor
# output in that directory instead of refusing to reuse it.
# NOTE(review): confirm this against the installed gym version's Monitor docs.
env = Monitor(env, 'tmp/cart-pole-cross-entropy-1', force=True)