ugo-nama-kun/cb852c7a2728f456037c4c2725048d37 (last active October 12, 2021 03:50)
Tips for image observations in reinforcement learning
```python
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
from dm_control import suite

# Using the DeepMind Control Suite
env_dmc = suite.load("cartpole", "balance")
env_dmc.reset()

# Images rendered by dm_control have shape (84, 84, 3) = (height, width, RGB); the same holds for mujoco_py
im = env_dmc.physics.render(camera_id=0, height=84, width=84)

# Frame stack: [(84, 84, 3), (84, 84, 3), (84, 84, 3)]
# Step the environment between renders so the frames come from consecutive time steps
# (a zero action is used here only for illustration).
action_spec = env_dmc.action_spec()
frames = deque(maxlen=3)
for _ in range(3):
    env_dmc.step(np.zeros(action_spec.shape))
    frames.append(env_dmc.physics.render(camera_id=0, height=84, width=84))

# NumPy array with shape (84, 84, 3*3), channels-last (for TensorFlow)
obs = np.dstack(frames)

# NumPy array with shape (3*3, 84, 84), channels-first (for PyTorch)
obs = np.dstack(frames).transpose(2, 0, 1)  # torch's permute corresponds to NumPy's transpose

# Display each of the 9 channels as a grayscale image
plt.figure()
for i in range(9):
    plt.subplot(3, 3, 1 + i)
    plt.imshow(obs[i], cmap="gray")
plt.show()
```
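In a training loop, a frame stack normally advances by one frame per environment step and, at the start of an episode, is filled with copies of the first frame. The class below is a minimal sketch of that pattern under these assumptions: `FrameStack` is a hypothetical helper (not part of dm_control), and dummy arrays stand in for `physics.render(...)` output so the example runs without MuJoCo.

```python
from collections import deque

import numpy as np


class FrameStack:
    """Hypothetical helper: keep the last k RGB frames of shape (H, W, 3) and stack them."""

    def __init__(self, k=3, channels_first=False):
        self.k = k
        self.channels_first = channels_first
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # At episode start only one frame exists, so repeat it k times.
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, new_frame):
        # deque(maxlen=k) drops the oldest frame automatically.
        self.frames.append(new_frame)
        return self.observation()

    def observation(self):
        obs = np.dstack(self.frames)  # (H, W, 3 * k), channels-last
        if self.channels_first:
            obs = obs.transpose(2, 0, 1)  # (3 * k, H, W), channels-first
        return obs


# Dummy frames: frame t is filled with the value t.
stack = FrameStack(k=3, channels_first=True)
obs = stack.reset(np.zeros((84, 84, 3), dtype=np.uint8))
for t in range(1, 5):
    obs = stack.step(np.full((84, 84, 3), t, dtype=np.uint8))
print(obs.shape)  # (9, 84, 84)
print(obs[0, 0, 0], obs[-1, 0, 0])  # 2 4
```

After four steps the stack holds frames 2, 3 and 4, oldest first, so the first channel belongs to the oldest frame and the last channel to the newest.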
In TensorFlow, the input tensor of a CNN has the layout [batch_size, height, width, channels] (channels-last, the default).
https://www.tensorflow.org/api_docs/python/tf/image
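The sketch below builds a channels-last batch from stacked observations and shows where each frame ends up. It assumes the (84, 84, 3 * 3) layout produced by `np.dstack(frames)`; the dummy frames are filled with their index k so the channel order is easy to check.

```python
import numpy as np

# Frame k is filled with the value k.
frames = [np.full((84, 84, 3), k, dtype=np.uint8) for k in range(3)]
obs = np.dstack(frames)  # (84, 84, 9)

# A batch for a channels-last CNN: [batch_size, height, width, channels].
batch_nhwc = np.stack([obs, obs, obs, obs], axis=0)
print(batch_nhwc.shape)  # (4, 84, 84, 9)

# np.dstack concatenates along the last axis, so frame k occupies channels 3k .. 3k+2.
k = 1
frame_k = obs[..., 3 * k : 3 * k + 3]
print(frame_k.shape, frame_k[0, 0])  # (84, 84, 3) [1 1 1]
```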
In PyTorch, the input tensor of a CNN (torch.nn.Conv2d) has the layout [batch_size, channels, height, width].
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
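To feed a channels-last batch to a channels-first network, the channel axis has to move from position 3 to position 1. A minimal NumPy sketch, using a random uint8 batch as a stand-in for real observations; `transpose(0, 3, 1, 2)` plays the role that `tensor.permute(0, 3, 1, 2)` plays in PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Channels-last batch: [batch_size, height, width, channels].
batch_nhwc = rng.integers(0, 256, size=(4, 84, 84, 9), dtype=np.uint8)

# Channels-first batch: [batch_size, channels, height, width].
batch_nchw = batch_nhwc.transpose(0, 3, 1, 2)
print(batch_nchw.shape)  # (4, 9, 84, 84)

# Equivalent to transposing each observation with (2, 0, 1) and then stacking.
per_sample = np.stack([o.transpose(2, 0, 1) for o in batch_nhwc])
print(np.array_equal(batch_nchw, per_sample))  # True
```

Transposing the whole batch at once avoids a Python loop and only rearranges strides, so no pixel data is copied until a contiguous array is actually needed.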