Skip to content

Instantly share code, notes, and snippets.

@rohithteja
Created August 21, 2021 14:15
Show Gist options
  • Select an option

  • Save rohithteja/7f84dffb6e37e26edf9fabce3c730815 to your computer and use it in GitHub Desktop.

Select an option

Save rohithteja/7f84dffb6e37e26edf9fabce3c730815 to your computer and use it in GitHub Desktop.
DGL GCN Model
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
import dgl.function as fn
class GCN(nn.Module):
    """Two-layer graph convolutional network built from DGL ``GraphConv`` layers.

    Pipeline: GraphConv -> ReLU -> dropout -> GraphConv -> log_softmax(dim=1).

    Args:
        in_feats: Width of each input node feature vector.
        h_feats: Width of the hidden layer.
        num_classes: Width of the output layer, i.e. the number of
            log-probabilities produced per node. Defaults to ``h_feats``,
            which reproduces the original two-argument behaviour. Pass the
            real class count for classification.
        dropout: Dropout probability applied to the hidden activations
            while the module is in training mode. Defaults to 0.5, the
            ``F.dropout`` default the original code relied on implicitly.
    """

    def __init__(self, in_feats, h_feats, num_classes=None, dropout=0.5):
        super().__init__()
        if num_classes is None:
            # Preserve the original behaviour: output width == hidden width.
            num_classes = h_feats
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
        self.dropout_p = dropout

    def forward(self, g, in_feat):
        """Run the two graph convolutions and return per-node log-probabilities.

        Args:
            g: The DGL graph passed to both ``GraphConv`` layers.
            in_feat: Node feature tensor. Presumably shaped
                (num_nodes, in_feats); TODO confirm against the caller.

        Returns:
            Tensor of log-softmax scores normalised along dim 1.
        """
        # First convolution computed exactly once. The original evaluated
        # conv1 twice and discarded the first result.
        h = F.relu(self.conv1(g, in_feat))
        # Dropout only takes effect while self.training is True.
        h = F.dropout(h, p=self.dropout_p, training=self.training)
        h = self.conv2(g, h)
        return F.log_softmax(h, dim=1)
# Read the node-level tensors stored on the graph `data`.
# NOTE(review): `data` is defined outside this snippet. Presumably it is a
# DGL graph (it exposes `.ndata`); confirm where it is loaded.
node_features, node_labels = data.ndata['feat'], data.ndata['label']
train_mask, test_mask = data.ndata['train_mask'], data.ndata['test_mask']

# Input width = size of the second axis of the feature tensor.
n_features = node_features.shape[1]

# Build the GCN with a hidden layer of 16 units.
model = GCN(n_features, 16)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment