Skip to content

Instantly share code, notes, and snippets.

@mlabonne
Created September 20, 2022 10:59
Show Gist options
  • Save mlabonne/19ca2165969fb3a8418299524500420e to your computer and use it in GitHub Desktop.
Save mlabonne/19ca2165969fb3a8418299524500420e to your computer and use it in GitHub Desktop.
GAT architecture for graph classification with global_add_pool
import torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_add_pool
class GAT(torch.nn.Module):
    """Graph Attention Network (GAT) for graph-level classification.

    Three stacked ``GATConv`` layers compute node embeddings. These are summed
    per graph with ``global_add_pool``, passed through dropout, and mapped to
    class scores by a linear layer. The model returns log-probabilities.

    Args:
        dim_h: Hidden size of every ``GATConv`` layer and input size of the
            final linear classifier.
        num_node_features: Size of the input node features. If ``None``
            (default), it is read from a module-level ``dataset`` object, which
            preserves the original ``GAT(dim_h)`` behaviour.
        num_classes: Number of output classes. If ``None`` (default), it is
            read from a module-level ``dataset`` object.
        dropout: Dropout probability applied to the pooled graph embedding
            while the module is in training mode.
    """

    def __init__(self, dim_h, *, num_node_features=None, num_classes=None,
                 dropout=0.5):
        super().__init__()
        # Only fall back to the global ``dataset`` when sizes are not given,
        # so existing ``GAT(dim_h)`` call sites keep working while the model
        # can also be constructed without relying on that hidden global.
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes

        self.dropout_p = dropout
        self.conv1 = GATConv(num_node_features, dim_h)
        self.conv2 = GATConv(dim_h, dim_h)
        self.conv3 = GATConv(dim_h, dim_h)
        self.lin = Linear(dim_h, num_classes)

    def forward(self, x, edge_index, batch):
        """Classify each graph in a (mini-)batch.

        Args:
            x: Node feature tensor; presumably ``[num_nodes, num_node_features]``
                as expected by ``GATConv`` (TODO confirm against caller).
            edge_index: Graph connectivity passed unchanged to every ``GATConv``.
            batch: Node-to-graph assignment vector consumed by
                ``global_add_pool`` to sum node embeddings per graph.

        Returns:
            ``log_softmax`` of the classifier output along dim 1, i.e. one row
            of class log-probabilities per pooled graph.
        """
        # Node embeddings
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        # No activation after the last conv: its raw output is pooled directly.
        h = self.conv3(h, edge_index)

        # Graph-level readout: sum the node embeddings belonging to each graph.
        h_graph = global_add_pool(h, batch)

        # Classifier; dropout is only active when ``self.training`` is True.
        h = F.dropout(h_graph, p=self.dropout_p, training=self.training)
        h = self.lin(h)
        return F.log_softmax(h, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment