Created October 28, 2020 15:50
-
-
Save BenedictWilkins/d8ecc4c10cc5032ebf3022895e48506f to your computer and use it in GitHub Desktop.
OpenAI Gym observation slicing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import gym | |
class BoxSlice(gym.spaces.Box):
    """A ``gym.spaces.Box`` restricted to the part selected by a numpy index.

    ``low``, ``high`` and ``shape`` are recomputed on every access by applying
    the stored index ``s_`` to the wrapped box's ``low``/``high`` arrays.
    Indexing a ``BoxSlice`` (``space[i]``) returns a new ``BoxSlice`` that
    wraps this one, so ``i`` is applied on top of the current slice.

    Args:
        box: the ``gym.spaces.Box`` to slice (may itself be a ``BoxSlice``).
        s_: any numpy index (slice, tuple of slices, Ellipsis, ...). The
            default ``np.s_[...]`` selects the whole array of any rank; for
            rank-3 boxes it is equivalent to the former ``np.s_[:, :, :]``.

    Raises:
        TypeError: if ``box`` is not a ``gym.spaces.Box``.
    """

    def __init__(self, box, s_=np.s_[...]):
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not isinstance(box, gym.spaces.Box):
            raise TypeError(
                "BoxSlice expects a gym.spaces.Box, got {}".format(
                    type(box).__name__
                )
            )
        # Set both before the parent constructor runs, because the overridden
        # properties below read them. NOTE(review): presumably Box.__init__
        # reads/assigns low, high and shape; verify against the installed gym.
        self.__box = box
        self.__slice = s_
        super().__init__(box.low, box.high, dtype=box.dtype)

    def __getitem__(self, i):
        """Return a new ``BoxSlice`` applying index ``i`` on top of this one."""
        return BoxSlice(self, i)

    @property
    def s_(self):
        """The index applied by this slice alone (not composed with any
        index of a wrapped ``BoxSlice``)."""
        return self.__slice

    @property
    def low(self):
        """Lower bounds of the wrapped box, restricted by ``s_``."""
        return self.__box.low[self.__slice]

    @low.setter
    def low(self, value):
        # Deliberately ignored: bounds are always derived from the wrapped box,
        # so assignments (e.g. from the parent constructor) must not clobber them.
        pass

    @property
    def high(self):
        """Upper bounds of the wrapped box, restricted by ``s_``."""
        return self.__box.high[self.__slice]

    @high.setter
    def high(self, value):
        # Deliberately ignored; see the ``low`` setter.
        pass

    @property
    def shape(self):
        """Shape of the sliced bounds."""
        return self.__box.low[self.__slice].shape

    @shape.setter
    def shape(self, value):
        # Deliberately ignored; shape follows from the sliced bounds.
        pass
class Slice(gym.Wrapper):
    """Environment wrapper that restricts every observation with a numpy index.

    The wrapper's ``observation_space`` is a ``BoxSlice`` of the wrapped
    environment's space, so the env's observation space must be a
    ``gym.spaces.Box``. ``env[i]`` returns a new ``Slice`` around this
    wrapper: the inner wrapper slices first, then ``i`` is applied to that
    result.

    Args:
        env: the environment to wrap.
        s_: any numpy index. The default ``np.s_[...]`` keeps the whole
            observation for any rank; for rank-3 observations it is
            equivalent to the former ``np.s_[:, :, :]``.
    """

    def __init__(self, env, s_=np.s_[...]):
        super().__init__(env)
        self.observation_space = BoxSlice(env.observation_space, s_=s_)

    def __getitem__(self, i):
        """Return a new ``Slice`` that applies index ``i`` on top of this wrapper."""
        return Slice(self, s_=i)

    def _slice_observation(self, observation):
        """Apply this wrapper's index to a single observation."""
        return observation[self.observation_space.s_]

    def step(self, action, *args, **kwargs):
        """Step the wrapped env and slice the observation.

        The first item of the wrapped ``step`` result is sliced; all
        remaining items are passed through unchanged.
        """
        observation, *rest = self.env.step(action, *args, **kwargs)
        return (self._slice_observation(observation), *rest)

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and slice the initial observation.

        Newer gym versions return ``(observation, info)`` from ``reset``.
        Since the observation space is a Box, a tuple result can only be
        that form: its first item is sliced and the rest passed through.
        """
        result = self.env.reset(*args, **kwargs)
        if isinstance(result, tuple):
            observation, *rest = result
            return (self._slice_observation(observation), *rest)
        return self._slice_observation(result)
if __name__ == "__main__":
    # Demo. NOTE(review): "Pong-v0" presumably needs gym's Atari extras installed.
    env = Slice(gym.make("Pong-v0"))
    try:
        # Drop the first row (index 0 along axis 0).
        print(env[1:].observation_space)
        # Rows 10-19, every other column, all channels.
        print(env[10:20, ::2, :].observation_space)
        # Full, unsliced observation.
        print(env.reset().shape)
        # Keep only the first channel (red, assuming RGB frames).
        print(env[:, :, :1].reset().shape)
    finally:
        # Release the underlying environment's resources.
        env.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment