Skip to content

Instantly share code, notes, and snippets.

@mlabonne
Last active May 25, 2022 13:02
Show Gist options
  • Save mlabonne/be636a5cb4c1249b084ff46afe26a853 to your computer and use it in GitHub Desktop.
Save mlabonne/be636a5cb4c1249b084ff46afe26a853 to your computer and use it in GitHub Desktop.
class CNN(nn.Module):
    """Convolutional network mapping image observations to ``output_dim`` outputs.

    Three Conv2d -> BatchNorm2d -> ReLU stages are followed by a flatten and a
    two-layer MLP head (hidden width 512). The input width of the first Linear
    layer is computed from ``input_shape`` rather than hard-coded, so spatial
    sizes other than 64x64 are supported.

    Args:
        input_shape: Shape of one observation, ``(channels, height, width)``.
            Only ``input_shape[0]`` is strictly required; if the spatial
            dimensions are omitted, 64x64 is assumed (the size the former
            hard-coded ``Linear(1024, 512)`` corresponded to).
        output_dim: Number of features produced by the final Linear layer.

    Raises:
        ValueError: If the spatial size is too small for the conv stack to
            produce at least a 1x1 feature map.
    """

    def __init__(self, input_shape, output_dim):
        super().__init__()
        n_input_channels = input_shape[0]
        conv_layers = [
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        n_flatten = self._flattened_size(input_shape, conv_layers)
        # Layer order (and therefore Sequential indices / state_dict keys such
        # as cnn.0.weight ... cnn.12.bias) is identical to the original model.
        self.cnn = nn.Sequential(
            *conv_layers,
            nn.Flatten(),
            nn.Linear(n_flatten, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

    @staticmethod
    def _flattened_size(input_shape, layers):
        """Return the per-sample feature count emitted by ``layers`` after flattening."""
        if len(input_shape) >= 3:
            height, width = input_shape[1], input_shape[2]
        else:
            # Only the channel count was given: keep the original 64x64
            # assumption (64 channels * 4 * 4 = 1024 features).
            height, width = 64, 64
        channels = input_shape[0]
        for layer in layers:
            if not isinstance(layer, nn.Conv2d):
                # BatchNorm2d and ReLU preserve channels and spatial size.
                continue
            (kh, kw), (sh, sw) = layer.kernel_size, layer.stride
            (ph, pw), (dh, dw) = layer.padding, layer.dilation
            # Output-size formula from the torch.nn.Conv2d documentation.
            height = (height + 2 * ph - dh * (kh - 1) - 1) // sh + 1
            width = (width + 2 * pw - dw * (kw - 1) - 1) // sw + 1
            if height <= 0 or width <= 0:
                raise ValueError(
                    f"input_shape {tuple(input_shape)} is too small for the "
                    "convolutional stack of CNN"
                )
            channels = layer.out_channels
        return channels * height * width

    def forward(self, observations):
        """Run the network on a batch of observations.

        Args:
            observations: Batched image tensor, ``(N, C, H, W)`` (BatchNorm2d
                requires 4-D input), with ``C``/``H``/``W`` matching the
                ``input_shape`` given at construction.

        Returns:
            Tensor of shape ``(N, output_dim)``.
        """
        return self.cnn(observations)
# NOTE(review): the body of this function is elided ("...") in this copy of the
# gist, so its behaviour cannot be documented from here. Going only by the name,
# it presumably converts a batch of dataset actions into actions -- TODO confirm
# against the full source. camera_margin defaults to 5; its meaning is not
# visible in this view.
def dataset_action_batch_to_actions(dataset_actions, camera_margin=5):
...
# Subclass of gym.ActionWrapper.
# NOTE(review): the class body is elided ("...") in this copy of the gist, so
# the action transformation it applies is not visible here -- consult the full
# source before relying on it.
class ActionShaping(gym.ActionWrapper):
...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.