Skip to content

Instantly share code, notes, and snippets.

@tomonori-masui
Last active September 21, 2022 04:37
Show Gist options
  • Save tomonori-masui/423ea8f3d19d69595027591634ddaa84 to your computer and use it in GitHub Desktop.
Save tomonori-masui/423ea8f3d19d69595027591634ddaa84 to your computer and use it in GitHub Desktop.
GCN node classifier
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network (GCN) for node classification.

    Architecture: ``GCNConv -> ReLU -> GCNConv``. The last layer emits one
    unnormalized score per class for every node. No softmax is applied, so
    the output is meant for a loss that expects raw logits, such as
    ``CrossEntropyLoss``.

    Args:
        num_node_features: Size of each input node feature vector. If omitted,
            it is read from the module-level ``dataset`` object
            (``dataset.num_node_features``), which preserves the original
            zero-argument ``GCN()`` behavior.
        num_classes: Number of output classes. If omitted, it is read from
            ``dataset.num_classes`` on the same module-level object.
        hidden_channels: Width of the hidden GCN layer (default ``16``, the
            value previously hard-coded).
    """

    def __init__(self, num_node_features=None, num_classes=None, hidden_channels=16):
        super().__init__()
        # Fall back to the global ``dataset`` only when the caller did not
        # supply sizes explicitly, so ``GCN()`` keeps working unchanged while
        # the class can also be built without that global.
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        """Compute per-node class scores.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index`` (connectivity). NOTE(review): presumably ``x``
                is ``(num_nodes, num_node_features)``; confirm against caller.

        Returns:
            Output of the second ``GCNConv`` layer, i.e. raw (pre-softmax)
            class scores for each node.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        output = self.conv2(x, edge_index)
        return output
# Build the model on the target device. Optimize all of its parameters with
# Adam (learning rate 0.01, weight decay / L2 penalty 5e-4).
gcn = GCN().to(device)
optimizer_gcn = torch.optim.Adam(gcn.parameters(), lr=0.01, weight_decay=5e-4)
# The model's final layer applies no softmax, so raw logits go straight into
# CrossEntropyLoss. It is spelled ``torch.nn`` to match ``torch.nn.Module`` /
# ``torch.optim`` elsewhere in this file, so the snippet does not rely on an
# ``nn`` alias that it never imports.
criterion = torch.nn.CrossEntropyLoss()
# ``train_node_classifier`` and ``eval_node_classifier`` are not defined in
# this snippet. The trained model is returned and rebound to ``gcn``.
gcn = train_node_classifier(gcn, graph, optimizer_gcn, criterion)
# Evaluate on the nodes selected by ``graph.test_mask``.
# NOTE(review): presumably a boolean per-node mask; confirm.
test_acc = eval_node_classifier(gcn, graph, graph.test_mask)
print(f'Test Acc: {test_acc:.3f}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment