Skip to content

Instantly share code, notes, and snippets.

@khuangaf
Last active May 30, 2019 15:28
Show Gist options
  • Save khuangaf/a614a38c932a6c13a5f8b72623d2c12b to your computer and use it in GitHub Desktop.
Save khuangaf/a614a38c932a6c13a5f8b72623d2c12b to your computer and use it in GitHub Desktop.
def train(*, net=None, loader=None, opt=None, criterion=None, dev=None):
    """Run one training epoch and return the mean per-graph loss.

    Called with no arguments, this trains the module-level ``model`` on
    ``train_loader`` with ``optimizer``, ``crit`` and ``device`` (the names the
    script sets up). Each can be overridden by keyword so the same loop can be
    reused for another model/loader or exercised in isolation.

    Args:
        net: Model to train; defaults to the global ``model``.
        loader: Iterable of batches exposing ``.dataset``; defaults to the
            global ``train_loader``.
        opt: Optimizer stepping ``net``'s parameters; defaults to ``optimizer``.
        criterion: Loss callable ``criterion(output, target)``; defaults to
            ``crit``.
        dev: Device each batch is moved to; defaults to ``device``.

    Returns:
        The sum of ``loss * data.num_graphs`` over all batches divided by
        ``len(loader.dataset)``. Assuming the criterion averages over the
        batch and ``num_graphs`` counts the samples in it, this is the mean
        loss per sample for the epoch.
    """
    # Resolve defaults at call time: the globals are assigned after this
    # function is defined, so they cannot be used as default values directly.
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    criterion = crit if criterion is None else criterion
    dev = device if dev is None else dev

    net.train()  # enable training-mode behaviour (dropout, batch-norm updates)
    loss_all = 0.0
    for data in loader:
        # Moving the batch also moves its target ``y``, so no second .to() call.
        data = data.to(dev)
        opt.zero_grad()
        output = net(data)
        loss = criterion(output, data.y)
        loss.backward()
        # Weight the batch-mean loss by batch size so the epoch average is
        # per sample rather than per batch (batches may differ in size).
        loss_all += data.num_graphs * loss.item()
        opt.step()
    # Divide by the size of the dataset actually iterated, not a separate global.
    return loss_all / len(loader.dataset)
# Prefer the GPU but fall back to CPU: with a hard-coded 'cuda' device the
# first .to(device) fails on machines without a CUDA-capable GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
# NOTE(review): BCELoss expects probabilities in [0, 1] and float targets of
# the same shape as the output — assumes Net ends in a sigmoid; confirm.
crit = torch.nn.BCELoss()
# Reshuffle every epoch so the model does not see batches in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    # train() returns the mean per-sample loss; report it instead of dropping it.
    epoch_loss = train()
    print(f'Epoch {epoch:03d}, train loss: {epoch_loss:.4f}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment