Last active
March 19, 2019 05:00
-
-
Save vwxyzjn/bb2075e55171106e4a5691f35f25d504 to your computer and use it in GitHub Desktop.
How OpenAI's Baselines (via its stable-baselines fork) handles different types of observation spaces and action spaces
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/hill-a/stable-baselines/blob/06f5843a3254ab7c2f6c927792e00365a778009e/stable_baselines/common/input.py#L6
def observation_input(ob_space, batch_size=None, name='Ob', scale=False):
    """
    Build observation input with encoding depending on the observation space type.

    - Discrete: int32 placeholder of shape (batch_size,), one-hot encoded to ob_space.n floats.
    - Box: placeholder with the space's shape and dtype, cast to float32. When ``scale`` is True
      and all bounds are finite, the input is normalized to [0, 1] using the bounds
      ob_space.low and ob_space.high.
    - MultiBinary: int32 placeholder of shape (batch_size, ob_space.n), cast to float32.
    - MultiDiscrete: int32 placeholder of shape (batch_size, len(ob_space.nvec)). Each column is
      one-hot encoded with its own depth and the encodings are concatenated on the last axis.

    :param ob_space: (Gym Space) The observation space
    :param batch_size: (int) batch size for input
        (default is None, so that resulting input placeholder can take tensors with any batch size)
    :param name: (str) tensorflow variable name for input placeholder
    :param scale: (bool) whether or not to scale the input (Box spaces only)
    :return: (TensorFlow Tensor, TensorFlow Tensor) input_placeholder, processed_input_tensor
    :raises NotImplementedError: if ob_space is not Discrete, Box, MultiBinary or MultiDiscrete
    """
    if isinstance(ob_space, Discrete):
        observation_ph = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name)
        processed_observations = tf.cast(tf.one_hot(observation_ph, ob_space.n), tf.float32)
        return observation_ph, processed_observations
    elif isinstance(ob_space, Box):
        observation_ph = tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name)
        processed_observations = tf.cast(observation_ph, tf.float32)
        # Rescale to [0, 1], but only if every bound is finite.
        if scale and not np.any(np.isinf(ob_space.low)) and not np.any(np.isinf(ob_space.high)):
            # Computed only after the finiteness check so inf - inf never produces NaN warnings.
            ob_range = ob_space.high - ob_space.low
            if np.any(ob_range != 0):
                # Dimensions with high == low would divide by zero and yield NaN/inf.
                # Divide them by 1 instead, which maps them to 0 (an in-bounds value there equals low).
                safe_range = np.where(ob_range != 0, ob_range, 1)
                # Equivalent to processed_observations / 255.0 when bounds are set to [0, 255].
                processed_observations = (processed_observations - ob_space.low) / safe_range
        return observation_ph, processed_observations
    elif isinstance(ob_space, MultiBinary):
        observation_ph = tf.placeholder(shape=(batch_size, ob_space.n), dtype=tf.int32, name=name)
        processed_observations = tf.cast(observation_ph, tf.float32)
        return observation_ph, processed_observations
    elif isinstance(ob_space, MultiDiscrete):
        n_components = len(ob_space.nvec)
        observation_ph = tf.placeholder(shape=(batch_size, n_components), dtype=tf.int32, name=name)
        # NOTE(review): tf.split along the last axis yields (batch_size, 1) slices, so each one-hot
        # encoding is (batch_size, 1, depth) and the concatenation is (batch_size, 1, sum(nvec)).
        # This shape is kept as-is for compatibility; presumably callers flatten it -- confirm.
        processed_observations = tf.concat([
            tf.cast(tf.one_hot(input_split, depth), tf.float32)
            for input_split, depth in zip(tf.split(observation_ph, n_components, axis=-1), ob_space.nvec)
        ], axis=-1)
        return observation_ph, processed_observations
    else:
        raise NotImplementedError("Error: the model does not support input space of type {}".format(
            type(ob_space).__name__))
# https://github.com/hill-a/stable-baselines/blob/06f5843a3254ab7c2f6c927792e00365a778009e/stable_baselines/common/distributions.py#L470
def make_proba_dist_type(ac_space):
    """
    Return an instance of ProbabilityDistributionType for the correct type of action space.

    - Box (1-D only) -> DiagGaussianProbabilityDistributionType(ac_space.shape[0])
    - Discrete       -> CategoricalProbabilityDistributionType(ac_space.n)
    - MultiDiscrete  -> MultiCategoricalProbabilityDistributionType(ac_space.nvec)
    - MultiBinary    -> BernoulliProbabilityDistributionType(ac_space.n)

    :param ac_space: (Gym Space) the input action space
    :return: (ProbabilityDistributionType) the appropriate instance of a ProbabilityDistributionType
    :raises ValueError: if ac_space is a Box whose shape is not 1-dimensional
    :raises NotImplementedError: if ac_space is not one of the space types listed above
    """
    if isinstance(ac_space, spaces.Box):
        # An explicit check instead of `assert`: asserts are stripped under `python -O`,
        # which would silently build a Gaussian over only the first axis of a multi-dim Box.
        if len(ac_space.shape) != 1:
            raise ValueError("Error: the action space must be a vector")
        return DiagGaussianProbabilityDistributionType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalProbabilityDistributionType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalProbabilityDistributionType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliProbabilityDistributionType(ac_space.n)
    else:
        raise NotImplementedError("Error: probability distribution, not implemented for action space of type {}."
                                  .format(type(ac_space)) +
                                  " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment