This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| env = gym.make('MineRLObtainDiamond-v0') | |
| env.seed(21) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Define the sequence of actions | |
| script = ['forward'] * 5 + [''] * 40 | |
| env = gym.make('MineRLObtainDiamond-v0') | |
| env = Recorder(env, './video', fps=60) | |
| env.seed(21) | |
| obs = env.reset() | |
| for action in script: | |
| # Get the action space (dict of possible actions) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| script = [] | |
| script += [''] * 20 | |
| script += ['forward'] * 5 | |
| script += ['attack'] * 61 | |
| script += ['camera:[-10,0]'] * 7 # Look up | |
| script += ['attack'] * 240 | |
| script += ['jump'] | |
| script += ['forward'] * 10 # Jump forward | |
| script += ['camera:[-10,0]'] * 2 # Look up | |
| script += ['attack'] * 150 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class CNN(nn.Module): | |
| def __init__(self, input_shape, output_dim): | |
| super().__init__() | |
| n_input_channels = input_shape[0] | |
| self.cnn = nn.Sequential( | |
| nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4), | |
| nn.BatchNorm2d(32), | |
| nn.ReLU(), | |
| nn.Conv2d(32, 64, kernel_size=4, stride=2), | |
| nn.BatchNorm2d(64), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Get data | |
| minerl.data.download(directory='data', environment='MineRLTreechop-v0') | |
| data = minerl.data.make("MineRLTreechop-v0", data_dir='data', num_workers=2) | |
| # Model | |
| model = CNN((3, 64, 64), 7).cuda() | |
| optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) | |
| criterion = nn.CrossEntropyLoss() | |
| # Training loop |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| model = CNN((3, 64, 64), 7).cuda() | |
| model.load_state_dict(torch.load('model.pth')) | |
| env = gym.make('MineRLObtainDiamond-v0') | |
| env1 = Recorder(env, './video', fps=60) | |
| env = ActionShaping(env1) | |
| action_list = np.arange(env.action_space.n) | |
| obs = env.reset() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| obs = env_script.reset() | |
| done = False | |
| # 1. Get wood with the CNN | |
| for i in tqdm(range(3000)): | |
| obs = torch.from_numpy(obs['pov'].transpose(2, 0, 1)[None].astype(np.float32) / 255).cuda() | |
| probabilities = torch.softmax(model(obs), dim=1)[0].detach().cpu().numpy() | |
| action = np.random.choice(action_list, p=probabilities) | |
| obs, reward, done, _ = env_script.step(action) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout | |
| from torch_geometric.nn import GATConv | |
| from torch_geometric.nn import global_add_pool | |
| class GAT(torch.nn.Module): | |
| def __init__(self, dim_h): | |
| super(GAT, self).__init__() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| def softmax(x, temperature=1.0): | |
| e_x = np.exp(x / temperature) | |
| return e_x / e_x.sum(axis=0) | |
| logits = np.array([1.5, -1.8, 0.9, -3.2]) | |
| temperatures = [1.0, 0.5, 0.1] |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.