Skip to content

Instantly share code, notes, and snippets.

@tomonori-masui
Last active September 23, 2022 06:02
Show Gist options
  • Select an option

  • Save tomonori-masui/8b1ae405655a5e829337072969a47a97 to your computer and use it in GitHub Desktop.

Select an option

Save tomonori-masui/8b1ae405655a5e829337072969a47a97 to your computer and use it in GitHub Desktop.
Training and evaluation functions for node classifier
def train_node_classifier(model, graph, optimizer, criterion, n_epochs=200, log_every=10):
    """Train ``model`` full-batch on the training nodes of ``graph``.

    Each epoch runs one forward pass over the whole graph, computes
    ``criterion`` on the nodes selected by ``graph.train_mask`` and takes one
    optimizer step. On every epoch divisible by ``log_every`` the validation
    accuracy on ``graph.val_mask`` is computed with ``eval_node_classifier``
    and printed together with the training loss.

    Args:
        model: Module called as ``model(graph)``.
        graph: Object exposing ``y``, ``train_mask`` and ``val_mask``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive ``model``'s
            parameters.
        criterion: Loss called as ``criterion(outputs, targets)`` on the
            training-node rows.
        n_epochs: Number of training epochs (one optimizer step each).
        log_every: Report progress on epochs divisible by this value;
            ``0`` or ``None`` disables reporting (and the validation passes).

    Returns:
        The same ``model`` object, left in evaluation mode.
    """
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph)
        loss = criterion(out[graph.train_mask], graph.y[graph.train_mask])
        loss.backward()
        optimizer.step()
        # Validation is a full extra forward pass over the graph, so only run
        # it on the epochs whose result is actually reported.
        if log_every and epoch % log_every == 0:
            acc = eval_node_classifier(model, graph, graph.val_mask)
            print(f'Epoch: {epoch:03d}, Train Loss: {loss.item():.3f}, Val Acc: {acc:.3f}')
    # Hand the model back in eval mode, ready for inference.
    model.eval()
    return model
def eval_node_classifier(model, graph, mask):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    Switches ``model`` to evaluation mode via ``model.eval()`` and leaves it
    there; callers that continue training must call ``model.train()`` again.

    Args:
        model: Module called as ``model(graph)``. Its output is reduced with
            ``argmax(dim=1)``, so it presumably returns per-node class scores
            of shape (num_nodes, num_classes) -- confirm against the model.
        graph: Object exposing ``y``, the per-node target labels.
        mask: Node mask used to index both predictions and labels. The
            denominator is ``mask.sum()``, so this must be a boolean mask
            (an index tensor would produce a wrong count).

    Returns:
        Fraction (float) of masked nodes whose predicted class equals
        ``graph.y``.

    Raises:
        ZeroDivisionError: If ``mask`` selects no nodes.
    """
    model.eval()
    # Inference never back-propagates, so skip recording an autograd graph
    # (saves memory and time on every evaluation call).
    with torch.no_grad():
        pred = model(graph).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    return int(correct) / int(mask.sum())
# Train the MLP node classifier for 150 epochs with Adam, then report its
# accuracy on the test nodes (graph.test_mask).
criterion = nn.CrossEntropyLoss()
mlp = MLP().to(device)
optimizer_mlp = torch.optim.Adam(
    mlp.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)
mlp = train_node_classifier(
    mlp,
    graph,
    optimizer_mlp,
    criterion,
    n_epochs=150,
)
test_acc = eval_node_classifier(mlp, graph, graph.test_mask)
print(f'Test Acc: {test_acc:.3f}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment