Skip to content

Instantly share code, notes, and snippets.

@rohithteja
Created August 10, 2021 12:10
Show Gist options
  • Save rohithteja/a583d2e732985b857a322dfa599277ca to your computer and use it in GitHub Desktop.
Save rohithteja/a583d2e732985b857a322dfa599277ca to your computer and use it in GitHub Desktop.
Train GCN model
# --- Training configuration -------------------------------------------------
epochs = 200             # upper bound used by the training loop
lr = 0.1                 # learning rate handed to the optimizer
optimizer_name = "Adam"  # resolved by name on torch.optim below

# Fix torch's global RNG so randomness during training (e.g. dropout, if the
# model uses any) is reproducible across runs.
# NOTE(review): `model` is constructed outside this view; if its initial
# weights should also be reproducible, the seed must be set before the model
# is built — confirm.
torch.manual_seed(42)

# Look the optimizer class up by name so it can be swapped by changing
# `optimizer_name` alone, then bind it to the model's parameters.
optimizer = getattr(torch.optim, optimizer_name)(
    model.parameters(),
    lr=lr,
)
def train():
    """Perform one full-batch optimisation step on the training nodes.

    Uses the module-level ``model``, ``optimizer`` and ``data``. The model is
    called with no arguments and is expected to return per-node log-probabilities
    (it is fed to ``F.nll_loss``); only rows selected by ``data.train_mask``
    contribute to the loss.
    """
    model.train()  # enable training-mode behaviour (e.g. dropout) if present
    optimizer.zero_grad()

    output = model()
    selected = data.train_mask
    loss = F.nll_loss(output[selected], data.y[selected])

    loss.backward()
    optimizer.step()
def _masked_accuracy(logits, labels, mask):
    """Return the fraction of rows selected by ``mask`` whose argmax matches ``labels``.

    Args:
        logits: per-row class scores; the prediction is the argmax over dim 1.
        labels: ground-truth class index for every row.
        mask: row selector applied to both ``logits`` and ``labels``. Assumed to
            be a boolean mask, since its ``sum()`` is used as the row count.

    Raises:
        ZeroDivisionError: if ``mask`` selects no rows.
    """
    predictions = logits[mask].argmax(dim=1)
    correct = predictions.eq(labels[mask]).sum().item()
    return correct / mask.sum().item()


@torch.no_grad()
def test(mask_names=("train_mask", "test_mask")):
    """Evaluate the module-level ``model`` on ``data`` and report accuracies.

    The model runs once in eval mode with gradients disabled. The same logits
    are then scored against every requested node mask.

    Args:
        mask_names: keys of boolean masks on ``data``. The default reproduces
            the original ``(train_accuracy, test_accuracy)`` pair; pass e.g.
            ``("train_mask", "val_mask", "test_mask")`` to score more splits.

    Returns:
        A tuple of accuracies in ``[0, 1]``, one per entry of ``mask_names``,
        in the same order.
    """
    model.eval()
    logits = model()
    # Materialised inside the function, so it still runs under no_grad.
    return tuple(
        _masked_accuracy(logits, data.y, data[name]) for name in mask_names
    )
# Run exactly `epochs` optimisation steps. The previous `range(1, epochs)`
# stopped one short and ran only epochs - 1 steps.
for epoch in range(1, epochs + 1):
    train()

# Evaluate the final model once on the train and test splits.
train_acc, test_acc = test()

print('#' * 70)
print(f'Train Accuracy: {train_acc}')
print(f'Test Accuracy: {test_acc}')
print('#' * 70)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment