Skip to content

Instantly share code, notes, and snippets.

@rohithteja
Created August 14, 2021 15:31
Show Gist options
  • Save rohithteja/02299846a6b815e4dfd644cdc170a943 to your computer and use it in GitHub Desktop.
Save rohithteja/02299846a6b815e4dfd644cdc170a943 to your computer and use it in GitHub Desktop.
Custom Dataset PyG
import torch
import pandas as pd
from torch_geometric.data import InMemoryDataset, Data
from sklearn.model_selection import train_test_split
import torch_geometric.transforms as T
# custom dataset
class KarateDataset(InMemoryDataset):
    """Single-graph in-memory dataset with boolean train/test node masks.

    The constructor reads ``G``, ``edge_index``, ``embeddings`` and ``labels``
    from the enclosing scope. None of them is defined in this snippet, so they
    must already exist when the class is instantiated.

    NOTE(review): judging from how they are used here, ``G`` exposes
    ``nodes()`` and ``number_of_nodes()`` (looks like a networkx graph),
    ``embeddings`` is a torch tensor and ``labels`` is a NumPy array --
    confirm against the code that defines them.
    """

    def __init__(self, transform=None):
        """Build the graph ``Data`` object and collate it into the dataset.

        Args:
            transform: Forwarded unchanged to ``InMemoryDataset.__init__`` as
                its second positional argument.
        """
        super().__init__('.', transform, None, None)

        n_nodes = G.number_of_nodes()

        data = Data(edge_index=edge_index)
        data.num_nodes = n_nodes
        # Node features (embeddings cast to float32).
        data.x = embeddings.type(torch.float32)
        # Labels. ``torch.from_numpy`` shares memory with the NumPy array, so
        # clone to keep the stored tensor independent of ``labels``.
        data.y = torch.from_numpy(labels).type(torch.long).clone().detach()
        data.num_classes = 2

        # 70/30 train/test split with a fixed seed; no validation split is
        # created. The label halves of the split are not needed: only the
        # index of each returned Series is used, i.e. the position of each
        # selected node in ``G.nodes()`` iteration order.
        train_nodes, test_nodes, _, _ = train_test_split(
            pd.Series(G.nodes()),
            pd.Series(labels),
            test_size=0.30,
            random_state=42,
        )

        data.train_mask = self._positions_to_mask(train_nodes.index, n_nodes)
        data.test_mask = self._positions_to_mask(test_nodes.index, n_nodes)

        self.data, self.slices = self.collate([data])

    @staticmethod
    def _positions_to_mask(positions, size):
        """Return a bool tensor of length ``size`` that is True at ``positions``.

        ``positions`` is a pandas Index of integer positions; it is converted
        to a long tensor explicitly instead of relying on torch accepting a
        pandas object as an index.
        """
        mask = torch.zeros(size, dtype=torch.bool)
        mask[torch.as_tensor(positions.to_numpy(), dtype=torch.long)] = True
        return mask

    def _download(self):
        """No-op override; always returns ``None``.

        NOTE(review): presumably this stops the base class from running its
        download step, since the data is built in ``__init__`` -- confirm
        against the installed torch_geometric version.
        """
        return

    def _process(self):
        """No-op override; always returns ``None``.

        NOTE(review): presumably this stops the base class from running its
        processing step, since the data is built in ``__init__`` -- confirm
        against the installed torch_geometric version.
        """
        return

    def __repr__(self):
        """Return the class name followed by ``()``."""
        return f'{self.__class__.__name__}()'
# Instantiate the dataset; the constructor builds the graph and collates it.
dataset = KarateDataset()
# The constructor collates a single Data object, so index 0 is that graph.
data = dataset[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment