Skip to content

Instantly share code, notes, and snippets.

@rohithteja
Created August 21, 2021 13:20
Show Gist options
  • Select an option

  • Save rohithteja/52a10958b88cca60b580770a5ec1e985 to your computer and use it in GitHub Desktop.

Select an option

Save rohithteja/52a10958b88cca60b580770a5ec1e985 to your computer and use it in GitHub Desktop.
DGL Custom Dataset
import pandas as pd
import numpy as np
import dgl
import torch
from dgl.data import DGLDataset
from sklearn.model_selection import train_test_split
# Build the per-node table: one row per node of H, with a 'nodes' column
# holding the node identifier and an 'embeddings' column holding whatever
# `embeddings` yields for that identifier via Series.map.
nodes = (
    pd.DataFrame(list(H.nodes()))
    .set_axis(['nodes'], axis=1)
    .assign(embeddings=lambda frame: frame['nodes'].map(embeddings))
)
# NOTE: this rebinds the module-level name `embeddings` from the per-node
# lookup used above to a single stacked tensor (one row per node).
embeddings = torch.from_numpy(np.stack(nodes['embeddings'].to_numpy()))
# custom dataset class
class KarateDataset(DGLDataset):
    """Single-graph DGL dataset assembled from module-level data.

    The graph structure is read from the module-level graph ``G``, node
    features from the ``embeddings`` column of the module-level ``nodes``
    frame, and node labels from the module-level ``labels`` array.  The
    resulting graph carries four node-data fields: ``feat``, ``label``,
    ``train_mask`` and ``test_mask``.
    """

    def __init__(self):
        # NOTE(review): the DGLDataset constructor is expected to invoke
        # process() -- confirm against the installed DGL version.
        super().__init__(name='KarateDataset')

    def process(self):
        """Build ``self.graph`` and attach features, labels and split masks."""
        node_features = torch.from_numpy(
            np.stack(nodes.embeddings.values)
        ).type(torch.float32)
        node_labels = torch.from_numpy(labels).type(torch.long)

        # Materialise the edge list once and reuse it for both endpoints
        # instead of walking G.edges() twice.
        edge_list = list(G.edges())
        edges_src = torch.tensor(
            [int(edge[0]) for edge in edge_list], dtype=torch.int32
        )
        edges_dst = torch.tensor(
            [int(edge[1]) for edge in edge_list], dtype=torch.int32
        )

        n_nodes = G.number_of_nodes()
        self.graph = dgl.graph((edges_src, edges_dst), num_nodes=n_nodes)
        self.graph.ndata['feat'] = node_features
        self.graph.ndata['label'] = node_labels

        # 70/30 train/test split with a fixed seed.  Only the positional
        # index of each returned Series is used below, so the label halves
        # of the split are discarded.
        train_nodes, test_nodes, _, _ = train_test_split(
            pd.Series(G.nodes()),
            pd.Series(labels),
            test_size=0.30,
            random_state=42,
        )

        # Boolean masks over all nodes; True marks membership in the split.
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[train_nodes.index] = True
        test_mask[test_nodes.index] = True
        self.graph.ndata['train_mask'] = train_mask
        self.graph.ndata['test_mask'] = test_mask

    def __getitem__(self, i):
        """Return the dataset's only graph.

        Raises:
            IndexError: if ``i`` is anything other than ``0`` (or ``-1``).
                This keeps indexing consistent with ``__len__`` and lets
                index-based iteration over the dataset terminate.
        """
        if i not in (0, -1):
            raise IndexError(
                f'KarateDataset holds a single graph; index {i!r} is out of range'
            )
        return self.graph

    def __len__(self):
        """Return the number of graphs in the dataset (always 1)."""
        return 1
# Instantiate the dataset and take its single graph (index 0).
data = KarateDataset()[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment