Skip to content

Instantly share code, notes, and snippets.

@KruskalLin
Last active October 18, 2024 16:04
Show Gist options
  • Save KruskalLin/f9f05b85e78d486e7630390d4d3fb3f1 to your computer and use it in GitHub Desktop.
Save KruskalLin/f9f05b85e78d486e7630390d4d3fb3f1 to your computer and use it in GitHub Desktop.
Read the EXP dataset
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import Linear
from torch_geometric.data import InMemoryDataset, Data
from minimal_frame.graph_group_decorator import undirected_weighted_sn_invariant_decorator, undirected_weighted_sn_equivariant_decorator
from torch_geometric.utils import to_dense_adj
class MLP(nn.Module):
    """Small MLP mapping a square ``(d, d)`` 2-D tensor to a ``(1, 1)`` output.

    The layers are interleaved with transposes so both axes of the input are
    mixed: ``fc1`` and ``fc2`` project a length-``d`` axis to 32 (first one
    axis, then the other), and ``fc3`` / ``fc4`` collapse each 32-wide axis
    to 1.

    Args:
        d: Side length of the square input matrix.
    """

    def __init__(self, d):
        super().__init__()
        # Layers are created in the order fc1, fc2, fc3, fc4 so parameter
        # initialisation draws from the RNG in that order.
        self.fc1 = Linear(d, 32)
        self.fc2 = Linear(d, 32)
        self.fc3 = Linear(32, 1)
        self.fc4 = Linear(32, 1)

    def forward(self, x):
        """Apply the network to ``x`` (2-D, shape ``(d, d)``); returns ``(1, 1)``."""
        # (d, d) -> (d, 32), then transpose -> (32, d)
        hidden = self.fc1(x).relu().T
        # (32, d) -> (32, 32), then transpose -> (32, 32)
        hidden = self.fc2(hidden).relu().T
        # (32, 32) -> (32, 1), then transpose -> (1, 32)
        hidden = self.fc3(hidden).T
        # (1, 32) -> (1, 1)
        return self.fc4(hidden)
def generate_Sn_matrix(n, rng=None):
    """Return a random ``n x n`` permutation matrix as a float32 tensor.

    The matrix is the identity with its rows shuffled, so every row and every
    column contains exactly one ``1``. Conjugating an adjacency matrix with it
    (``S.T @ A @ S``) relabels the graph's nodes.

    Args:
        n: Size of the permutation.
        rng: Optional ``numpy.random.Generator`` used to draw the permutation,
            e.g. for reproducible tests. When ``None`` (default) the global
            ``numpy.random`` state is used, matching the previous behaviour.

    Returns:
        A ``torch.float32`` tensor of shape ``(n, n)``.
    """
    perm = np.random.permutation(n) if rng is None else rng.permutation(n)
    # Use torch.as_tensor with an explicit dtype rather than the legacy
    # torch.FloatTensor(...) type constructor.
    return torch.as_tensor(np.eye(n)[perm], dtype=torch.float32)
def pyg2_data_transform(data: Data):
    """Rebuild an old-format ``Data`` object for PyG 2.0 or later.

    When running on PyG >= 2.0 and ``data`` has no ``_store`` entry in its
    ``__dict__``, it is treated as an old-format object. A new ``Data`` is
    then constructed from its ``__dict__`` items, dropping entries whose
    value is ``None``. In every other case ``data`` is returned unchanged.

    Args:
        data: A ``Data`` object, possibly unpickled from an older PyG version.

    Returns:
        A ``Data`` object usable with the installed PyG version.
    """
    # Compare the numeric major version. A plain string comparison
    # (``__version__ >= "2.0"``) is lexicographic and misorders multi-digit
    # majors, e.g. "10.0" >= "2.0" is False.
    pyg_major = int(torch_geometric.__version__.split(".")[0])
    if pyg_major >= 2 and "_store" not in data.__dict__:
        return Data(**{k: v for k, v in data.__dict__.items() if v is not None})
    return data
class GraphDataset(InMemoryDataset):
    """In-memory PyG dataset built from the pickled graphs in ``GRAPHSAT.pkl``.

    ``process`` unpickles ``<root>/GRAPHSAT.pkl`` (expected to hold an
    iterable of ``Data`` objects) and converts each one with
    ``pyg2_data_transform``. It then applies the optional ``pre_filter`` and
    ``pre_transform``, collates the result and saves it to
    ``self.processed_paths[0]`` (file name ``data.pt``). ``__init__`` loads
    that cached file.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): since torch 2.6, torch.load defaults to
        # weights_only=True, which may refuse to unpickle PyG Data objects.
        # Newer PyG releases offer self.load()/self.save() for this; confirm
        # against the pinned versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["GRAPHSAT.pkl"]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # No-op: GRAPHSAT.pkl has to be provided manually.
        pass

    def process(self):
        # NOTE(review): raw_file_names declares GRAPHSAT.pkl, which PyG
        # presumably resolves under self.raw_dir (self.raw_paths). This method,
        # however, reads it from the dataset root, so that is where the file
        # must actually live.
        path = os.path.join(self.root, "GRAPHSAT.pkl")
        # SECURITY: pickle.load can execute arbitrary code; only load trusted
        # files. The `with` block ensures the handle is closed (previously
        # leaked).
        with open(path, "rb") as f:
            data_list = pickle.load(f)
        data_list = [pyg2_data_transform(data) for data in data_list]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
def _to_dense_input(graph, num_nodes):
    """Build the model input for ``graph``: a dense float adjacency matrix of
    shape ``(num_nodes, num_nodes)`` with ``graph.x.squeeze(-1)`` written onto
    its diagonal. Edges and features both come from the same ``graph``."""
    adj = to_dense_adj(graph.edge_index, max_num_nodes=num_nodes).squeeze(0).float()
    diag = torch.arange(num_nodes)
    adj[diag, diag] = graph.x.squeeze(-1).float()
    return adj


if __name__ == '__main__':
    dataset = GraphDataset("./")
    g0, g1 = dataset[0], dataset[1]

    num_nodes = g0.x.size(0)
    if g1.x.size(0) != num_nodes:
        # The MLP is sized for graph 0, so both graphs must have the same
        # number of nodes to be compared.
        raise ValueError(
            f"graphs 0 and 1 have different node counts "
            f"({num_nodes} vs {g1.x.size(0)})"
        )

    model = MLP(num_nodes)
    model.forward = undirected_weighted_sn_invariant_decorator(model.forward)

    # Isomorphism test: feed a node-relabelled copy of graph 0 (S.T @ A @ S).
    # The difference is expected to be ~0 if the decorated model is
    # permutation-invariant.
    adj0 = _to_dense_input(g0, num_nodes)
    y1 = model(adj0)
    S = generate_Sn_matrix(num_nodes)
    y2 = model(S.T @ adj0 @ S)
    print(y2 - y1)

    # Non-isomorphism test. Graph 1 is presumably not isomorphic to graph 0
    # (per this test's label). It is built from its own edge_index; the
    # previous code reused graph 0's edges, so it did not test graph 1's
    # structure at all.
    y2 = model(_to_dense_input(g1, num_nodes))
    print(y2 - y1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment