Generate Image Dataset of Pendulum-v0
import pickle
import argparse
from utils import *
from memory import *

BIT_DEPTH = 5


def rollout(memory, env):
    # Collect one episode of random interaction and store it in memory.
    episode = Episode(memory.device, BIT_DEPTH)
    x = env.reset()
    for _ in range(env.env._max_episode_steps):
        u = env.sample_random_action()
        nx, r, d, _ = env.step(u)
        episode.append(x, u, r, d)
        x = nx
    episode.append_last_obs(x)
    memory.append(episode)


def main(env_name, path):
    env = TorchImageEnvWrapper(env_name, BIT_DEPTH)
    memory = Memory(2000, None, 50)
    for _ in range(2000):
        rollout(memory, env)
    env.close()
    with open(path, 'wb+') as f:
        pickle.dump(memory, f)
    print('DONE!!!')
    print('Thanks. Now move it to Scratch and ping me!! :P')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Visual Dataset Generator')
    parser.add_argument('env', type=str, help='Name of Gym Environment.')
    parser.add_argument(
        '--output', type=str, default='memory.pth', help='Name of output file'
    )
    args = parser.parse_args()
    main(args.env, args.output)
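A hypothetical invocation and reload of the pickled dataset (the gist does not name this script; generate_data.py is a stand-in, and memory.py / utils.py must be importable when unpickling):

python generate_data.py Pendulum-v0 --output memory.pth

import pickle

with open('memory.pth', 'rb') as f:   # default --output name
    memory = pickle.load(f)           # needs memory.py and utils.py on the path
# batches are zero-padded to tracelen; seq_lengths holds the true lengths
batch, seq_lengths = memory.sample(32)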
memory.py
import pdb
import torch
import numpy as np
from utils import *
from collections import deque
from numpy.random import choice
from torch import float32 as F32


class Episode:
    """Stores one episode as quantized uint8 frames plus actions, rewards, dones."""

    def __init__(self, device, bit_depth):
        self.device = device
        self.bit_depth = bit_depth
        self.clear()

    @property
    def size(self):
        return self._size

    def clear(self):
        self.x = []
        self.u = []
        self.d = []
        self.r = []
        self._size = 0

    def append(self, x, u, r, d):
        self._size += 1
        self.x.append(postprocess_img(x.numpy(), self.bit_depth))
        self.u.append(u.numpy())
        self.r.append(r)
        self.d.append(d)

    def append_last_obs(self, x):
        # The observation sequence is one step longer than the action sequence.
        self.x.append(postprocess_img(x.numpy(), self.bit_depth))

    def prepare(self, s=0, e=None):
        # Return tensors for steps [s, e); observations cover [s, e] inclusive.
        e = e or self.size
        prossx = torch.tensor(np.stack(self.x[s:e + 1]), dtype=F32, device=self.device)
        preprocess_img(prossx, self.bit_depth)  # in place: uint8 range -> [-0.5, 0.5]
        return (
            prossx,
            torch.tensor(self.u[s:e], dtype=F32, device=self.device),
            torch.tensor(self.r[s:e], dtype=F32, device=self.device),
            torch.tensor(self.d[s:e], dtype=F32, device=self.device),
        )


class Memory:
    """Ring buffer of Episodes that samples fixed-length, zero-padded traces."""

    def __init__(self, size, device, tracelen):
        self.device = device
        self._shapes = None
        self.tracelen = tracelen
        self.data = deque(maxlen=size)
        self._empty_batch = None

    @property
    def size(self):
        return len(self.data)

    @property
    def shapes(self):
        return self._shapes

    def get_empty_batch(self, batch_size):
        # Cache a zero batch; observations get one extra timestep (tracelen + 1).
        if (self._empty_batch is None
                or self._empty_batch[0].size(0) != batch_size):
            data = []
            for i, s in enumerate(self.shapes):
                h = self.tracelen + 1 if not i else self.tracelen
                data.append(torch.zeros(batch_size, h, *s).to(self.device))
            self._empty_batch = data
        return [x.clone() for x in self._empty_batch]

    def append(self, episode: Episode):
        self.data.append(episode)
        if self.shapes is None:
            # Store the shapes of objects
            self._shapes = [a.shape[1:] for a in episode.prepare(e=1)]

    def sample(self, batch_size):
        episode_idx = choice(self.size, batch_size)
        init_st_idx = [choice(self.data[i].size) for i in episode_idx]
        data = self.get_empty_batch(batch_size)
        seq_lengths = []
        try:
            for n, (i, s) in enumerate(zip(episode_idx, init_st_idx)):
                x, u, r, d = self.data[i].prepare(s, s + self.tracelen)
                data[0][n, :x.size(0)] = x
                data[1][n, :u.size(0)] = u
                data[2][n, :r.size(0)] = r
                data[3][n, :d.size(0)] = d
                seq_lengths.append(len(d))
            return data, seq_lengths
        except Exception:
            # Debug hook: drop into the debugger on any sampling error.
            pdb.set_trace()
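Because sample() zero-pads any trace that runs past the end of an episode, training code is expected to use the returned seq_lengths to ignore the padding. A minimal sketch with get_mask from utils (the reward predictions r_pred are a hypothetical stand-in):

data, seq_lengths = memory.sample(16)
x, u, r, d = data
mask = get_mask(r, seq_lengths)                        # 1 for real steps, 0 for padding
loss = (mask * (r_pred - r) ** 2).sum() / mask.sum()   # masked MSE on rewards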
utils.py
import pdb
import cv2
import gym
import torch
import numpy as np
from torchvision.utils import make_grid, save_image


def to_tensor_obs(image):
    """
    Converts the input np img to a channel-first 64x64 torch img.
    """
    image = cv2.resize(image, (64, 64), interpolation=cv2.INTER_LINEAR)
    image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
    return image


def postprocess_img(image, depth):
    """
    Postprocess an image observation for storage.
    From float32 numpy array in [-0.5, 0.5] to uint8 numpy array in [0, 255].
    """
    image = np.floor((image + 0.5) * 2 ** depth)
    return np.clip(image * 2 ** (8 - depth), 0, 2 ** 8 - 1).astype(np.uint8)


def preprocess_img(image, depth):
    """
    Preprocesses an observation in place.
    From float32 Tensor in [0, 255] to [-0.5, 0.5], quantized to `depth` bits,
    with uniform dequantization noise added.
    """
    image.div_(2 ** (8 - depth)).floor_().div_(2 ** depth).sub_(0.5)
    image.add_(torch.rand_like(image).div_(2 ** depth))


def get_combined_params(*models):
    """
    Returns the combined parameter list of all the models given as input.
    """
    params = []
    for model in models:
        params.extend(list(model.parameters()))
    return params


def save_frames(target, pred_prior, pred_posterior, name, n_rows=5):
    """
    Saves the target images alongside the prior and posterior predictions.
    """
    image = torch.cat([target, pred_prior, pred_posterior], dim=3)
    save_image(make_grid(image + 0.5, nrow=n_rows), f'{name}.png')


def get_mask(tensor, lengths):
    """
    Generates masks for batches of sequences.
    Batch should be the first axis.
    input:
        tensor: the tensor for which to generate the mask [N x T x ...]
        lengths: lengths of the sequences [N]
    """
    mask = torch.zeros_like(tensor)
    for i in range(len(lengths)):
        mask[i, :lengths[i]] = 1.
    return mask


def apply_model(model, inputs, ignore_dim=None):
    # Placeholder, not implemented yet.
    pass


class TorchImageEnvWrapper:
    """
    Wraps a gym env so that interactions use Tensors and observations are
    rendered images.
    """
    def __init__(self, env, bit_depth, observation_shape=None):
        self.env = gym.make(env)
        self.bit_depth = bit_depth

    def reset(self):
        self.env.reset()
        x = to_tensor_obs(self.env.render(mode='rgb_array'))
        preprocess_img(x, self.bit_depth)
        return x

    def step(self, u):
        _, r, d, i = self.env.step(u.detach().numpy())
        x = to_tensor_obs(self.env.render(mode='rgb_array'))
        preprocess_img(x, self.bit_depth)
        return x, r, d, i

    def render(self):
        self.env.render()

    def close(self):
        self.env.close()

    @property
    def observation_size(self):
        return (3, 64, 64)

    @property
    def action_size(self):
        return self.env.action_space.shape[0]

    def sample_random_action(self):
        return torch.tensor(self.env.action_space.sample())
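A quick sanity check of the bit-depth round trip (a random tensor stands in for a rendered frame; the tolerance follows from one quantization step plus the added dequantization noise, each at most 1 / 2**5):

x = torch.rand(3, 64, 64) - 0.5           # stand-in frame in [-0.5, 0.5)
stored = postprocess_img(x.numpy(), 5)    # uint8 in [0, 255], 2**5 grey levels per channel
restored = torch.tensor(stored, dtype=torch.float32)
preprocess_img(restored, 5)               # in place: back to [-0.5, 0.5) plus noise
assert (x - restored).abs().max() < 2 / 2 ** 5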