Skip to content

Instantly share code, notes, and snippets.

@rohithteja
Created August 21, 2021 14:07
Show Gist options
  • Save rohithteja/d378a1cdee32b7d55d79e9c9756de157 to your computer and use it in GitHub Desktop.
Save rohithteja/d378a1cdee32b7d55d79e9c9756de157 to your computer and use it in GitHub Desktop.
Spektral Custom Dataset
import networkx as nx
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from spektral.data import Dataset
from spektral.data import Graph
from spektral.layers import GCNConv
from spektral.transforms import AdjToSpTensor, LayerPreprocess
# spektral custom dataset class
class KarateDataset(Dataset):
    """Spektral dataset holding a single graph with one-hot node targets.

    ``read`` builds exactly one ``Graph`` with three parts:

    - node features taken from ``feats``;
    - the adjacency matrix of a networkx graph, with rows and columns
      ordered by ``nodes``;
    - the one-hot encoded node labels.

    Args:
        nodes: Node identifiers of the graph. They fix the row/column order
            of the adjacency matrix. They are presumably in the same order as
            the rows of ``feats`` and ``labels`` (confirm at the call site).
        feats: Node feature array, cast to float32. Assumed to be
            (n_nodes, n_features); TODO confirm.
        graph: networkx graph to read the adjacency from. If omitted, the
            module-level ``G`` is used (the original behaviour).
        labels: One class label per node, as a list or array. If omitted,
            the module-level ``labels`` is used (the original behaviour).
        **kwargs: Forwarded to ``spektral.data.Dataset`` (e.g. ``transforms``).
    """

    def __init__(self, nodes, feats, *, graph=None, labels=None, **kwargs):
        # Store everything read() needs before calling the base constructor,
        # because Spektral's Dataset.__init__ invokes read().
        self.nodes = nodes
        self.feats = feats
        self._graph = graph
        self._labels = labels
        super().__init__(**kwargs)

    def read(self):
        """Return a one-element list containing the dataset's ``Graph``.

        Raises:
            ValueError: If the number of labels differs from the number of
                nodes.
        """
        graph = getattr(self, "_graph", None)
        if graph is None:
            graph = G  # module-level graph, kept for backward compatibility
        node_labels = getattr(self, "_labels", None)
        if node_labels is None:
            node_labels = labels  # module-level labels, kept for compatibility

        # Use plain Python node ids (not numpy scalars) for networkx lookups.
        if hasattr(self.nodes, "tolist"):
            nodelist = self.nodes.tolist()
        else:
            nodelist = list(self.nodes)

        # networkx 3.0 removed to_scipy_sparse_matrix in favour of
        # to_scipy_sparse_array. Use whichever exists, then normalise the
        # result to a float32 scipy CSR matrix as before.
        to_sparse = (
            getattr(nx, "to_scipy_sparse_array", None)
            or nx.to_scipy_sparse_matrix
        )
        a = sp.csr_matrix(to_sparse(graph, nodelist=nodelist), dtype="float32")

        # np.asarray lets a plain list of labels work (lists have no reshape).
        y = np.asarray(node_labels).reshape(-1, 1)
        if y.shape[0] != len(nodelist):
            raise ValueError(
                f"got {y.shape[0]} labels for {len(nodelist)} nodes"
            )
        # toarray() yields a regular ndarray; todense() would give np.matrix.
        y_onehot = OneHotEncoder().fit_transform(y).toarray().astype("float32")

        return [Graph(x=self.feats.astype("float32"), a=a, y=y_onehot)]
# Enumerate the node ids once. The dataset and the masks below share this
# list, so both use the same node order.
node = np.array(list(G.nodes()))
n_nodes = node.shape[0]

dataset = KarateDataset(
    nodes=node,
    feats=embeddings.numpy(),
    # Apply GCNConv's adjacency preprocessing, then convert the adjacency
    # to a sparse tensor.
    transforms=[LayerPreprocess(GCNConv), AdjToSpTensor()],
)
data = dataset[0]

# Create train and test masks: 70/30 split with a fixed seed for
# reproducibility.
# NOTE(review): the split is not stratified. With few nodes, a class may be
# under-represented in one split; consider stratify=labels if that matters.
X_train, X_test, y_train, y_test = train_test_split(
    pd.Series(node),
    pd.Series(labels),
    test_size=0.30,
    random_state=42,
)
# Each split keeps its Series' original index, i.e. the node positions.
# Convert those positions explicitly to int64 tensors instead of indexing
# a torch tensor with a pandas Index.
train_idx = torch.as_tensor(X_train.index.to_numpy(), dtype=torch.long)
test_idx = torch.as_tensor(X_test.index.to_numpy(), dtype=torch.long)
train_mask = torch.zeros(n_nodes, dtype=torch.float32)
test_mask = torch.zeros(n_nodes, dtype=torch.float32)
train_mask[train_idx] = 1
test_mask[test_idx] = 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment