Last active
September 23, 2022 06:02
-
-
Save tomonori-masui/8b1ae405655a5e829337072969a47a97 to your computer and use it in GitHub Desktop.
Training and evaluation functions for a node classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_node_classifier(model, graph, optimizer, criterion, n_epochs=200, log_every=10):
    """Train a node classifier with full-batch updates on a single graph.

    Each epoch runs one forward pass over the whole graph and computes the loss
    only on the nodes selected by ``graph.train_mask``. It then takes one
    optimizer step. Every ``log_every`` epochs, the validation accuracy on
    ``graph.val_mask`` is computed with ``eval_node_classifier`` and printed
    alongside the training loss.

    Args:
        model: Module called as ``model(graph)``. Presumably it returns
            per-node class logits of shape (num_nodes, num_classes); TODO confirm.
        graph: Object exposing ``y``, ``train_mask`` and ``val_mask``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, targets)``.
        n_epochs: Number of training epochs.
        log_every: Report progress every this many epochs. A falsy value
            (``0`` or ``None``) disables reporting and the validation passes.

    Returns:
        The same ``model`` object, left in evaluation mode.
    """
    for epoch in range(1, n_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(graph)
        loss = criterion(out[graph.train_mask], graph.y[graph.train_mask])
        loss.backward()
        optimizer.step()
        # Validation costs an extra forward pass, so only run it on the
        # epochs that are actually reported.
        if log_every and epoch % log_every == 0:
            acc = eval_node_classifier(model, graph, graph.val_mask)
            print(f'Epoch: {epoch:03d}, Train Loss: {loss.item():.3f}, Val Acc: {acc:.3f}')
    # Hand the model back ready for inference (dropout/batch-norm in eval mode).
    model.eval()
    return model
def eval_node_classifier(model, graph, mask):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is switched to evaluation mode and run once over the whole
    graph. The arg-max class per node is compared with ``graph.y`` on the
    masked nodes.

    Args:
        model: Module called as ``model(graph)``. Presumably it returns
            per-node class scores of shape (num_nodes, num_classes); TODO confirm.
        graph: Object exposing the label tensor ``y``.
        mask: Boolean tensor selecting which nodes to score.

    Returns:
        Fraction of masked nodes whose prediction equals the label, as a float.

    Raises:
        ZeroDivisionError: If ``mask`` selects no nodes.
    """
    import torch  # the file already depends on torch; local import keeps this block self-contained

    model.eval()
    # Pure inference: disable autograd so no gradient graph/activations are kept.
    with torch.no_grad():
        pred = model(graph).argmax(dim=1)
    correct = (pred[mask] == graph.y[mask]).sum()
    return int(correct) / int(mask.sum())
# Baseline run: train an MLP on `graph`, then report its accuracy on the
# nodes selected by `graph.test_mask`.
mlp = MLP()
mlp = mlp.to(device)
optimizer_mlp = torch.optim.Adam(
    mlp.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)
criterion = nn.CrossEntropyLoss()

# Fit for 150 epochs, then score once on the test split.
mlp = train_node_classifier(
    model=mlp,
    graph=graph,
    optimizer=optimizer_mlp,
    criterion=criterion,
    n_epochs=150,
)
test_acc = eval_node_classifier(model=mlp, graph=graph, mask=graph.test_mask)
print(f'Test Acc: {test_acc:.3f}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment