Last active
October 18, 2024 16:04
-
-
Save KruskalLin/f9f05b85e78d486e7630390d4d3fb3f1 to your computer and use it in GitHub Desktop.
Read the EXP dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pickle | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch_geometric | |
from torch_geometric.nn import Linear | |
from torch_geometric.data import InMemoryDataset, Data | |
from minimal_frame.graph_group_decorator import undirected_weighted_sn_invariant_decorator, undirected_weighted_sn_equivariant_decorator | |
from torch_geometric.utils import to_dense_adj | |
class MLP(nn.Module):
    """Small scalar-output network that alternately mixes along the two axes of its input.

    Every stage applies a ``Linear`` map along the last axis of the current tensor and
    (except the last) transposes the result, so consecutive layers act on alternating
    axes. For a square ``(d, d)`` input the shapes go
    ``(d, d) -> (32, d) -> (32, 32) -> (1, 32) -> (1, 1)``.

    Args:
        d: Width of the input's last axis, and also the width ``fc2`` expects after
            the first transpose (both equal ``d`` for a ``(d, d)`` input).
    """

    def __init__(self, d):
        super().__init__()
        hidden = 32
        # Layers are created in the same fixed order (fc1..fc4) so that random
        # initialisation is reproducible under a fixed seed.
        self.fc1 = Linear(d, hidden)
        self.fc2 = Linear(d, hidden)
        self.fc3 = Linear(hidden, 1)
        self.fc4 = Linear(hidden, 1)

    def forward(self, x):
        """Map ``x`` (expected ``(d, d)``) to a ``(1, 1)`` tensor."""
        h = x
        # Two ReLU stages, each followed by a transpose so the next layer sees the other axis.
        for layer in (self.fc1, self.fc2):
            h = F.relu(layer(h)).T
        # Linear read-out along one axis, then the other.
        h = self.fc3(h).T
        return self.fc4(h)
def generate_Sn_matrix(n, rng=None):
    """Return a uniformly random ``n x n`` permutation matrix as a float32 tensor.

    Row ``i`` of the result is row ``perm[i]`` of the identity, where ``perm`` is a
    random permutation of ``range(n)``. The result is orthogonal, so
    ``S.T @ A @ S`` relabels the nodes of an adjacency matrix ``A``.

    Args:
        n: Matrix size (number of elements being permuted).
        rng: Optional NumPy random source exposing ``permutation`` (e.g. a
            ``np.random.Generator``) for reproducible draws. Defaults to the global
            ``np.random`` state, which matches the previous behaviour, so
            ``np.random.seed`` still controls the output.

    Returns:
        ``torch.Tensor`` of shape ``(n, n)`` and dtype ``float32``.
    """
    if rng is None:
        rng = np.random
    perm = rng.permutation(n)
    # Build the matrix in NumPy, then convert with the modern tensor factory rather
    # than the legacy ``torch.FloatTensor(ndarray)`` constructor.
    return torch.as_tensor(np.eye(n)[perm], dtype=torch.float32)
def pyg2_data_transform(data: Data):
    """Convert a ``Data`` object stored in the pre-2.0 PyG format to the current format.

    When the installed PyG is 2.0 or later and ``data`` has no ``_store`` attribute
    (i.e. it was pickled by an older PyG), a fresh ``Data`` is rebuilt from the
    object's ``__dict__``, skipping attributes whose value is ``None``. In every
    other case ``data`` is returned unchanged.

    Args:
        data: A (possibly old-format) ``torch_geometric.data.Data`` instance.

    Returns:
        A ``Data`` object usable with the installed PyG version.
    """
    # Compare the major version numerically: a plain string comparison would rank
    # e.g. "10.0" below "2.0" and skip the conversion.
    major_version = int(torch_geometric.__version__.split(".", 1)[0])
    if major_version >= 2 and "_store" not in data.__dict__:
        return Data(
            **{k: v for k, v in data.__dict__.items() if v is not None}
        )
    return data
class GraphDataset(InMemoryDataset):
    """In-memory PyG dataset built from a pickled list of graphs in ``GRAPHSAT.pkl``.

    ``process`` reads ``<root>/GRAPHSAT.pkl``, converts each graph with
    ``pyg2_data_transform``, applies ``pre_filter`` / ``pre_transform`` when given,
    collates the list and saves ``(data, slices)`` to ``self.processed_paths[0]``
    (file name ``data.pt``). The constructor loads that file back.

    NOTE(review): ``raw_file_names`` declares ``GRAPHSAT.pkl`` (resolved through
    ``self.raw_paths``) while ``process`` reads it from ``self.root`` directly;
    confirm where the file is expected to live.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_processed(self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load the ``(data, slices)`` tuple that ``process`` saved with ``torch.save``.

        PyTorch 2.6 changed the default of ``torch.load`` to ``weights_only=True``,
        which restricts unpickling to allow-listed types and may reject the collated
        PyG objects. This file is written by this class itself, so full unpickling is
        requested explicitly. ``weights_only`` is only passed when the installed
        ``torch.load`` accepts it (older releases do not have the parameter).
        """
        import inspect

        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)

    @property
    def raw_file_names(self):
        return ["GRAPHSAT.pkl"]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # Nothing to fetch: GRAPHSAT.pkl must be provided locally.
        pass

    def process(self):
        # SECURITY: pickle can execute arbitrary code while loading; only use a
        # trusted GRAPHSAT.pkl.
        with open(os.path.join(self.root, "GRAPHSAT.pkl"), "rb") as f:
            data_list = pickle.load(f)
        data_list = [pyg2_data_transform(data) for data in data_list]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
def _dense_adj_with_features(data):
    """Return graph ``data`` as a dense float adjacency with node features on the diagonal.

    Assumes ``data.x`` holds one scalar feature per node, shaped ``(num_nodes, 1)``
    or ``(num_nodes,)`` (it is squeezed on its last axis) — TODO confirm for the dataset.
    """
    num_nodes = data.x.size(0)
    features = data.x.squeeze(-1).float()
    adj = to_dense_adj(data.edge_index, max_num_nodes=num_nodes).squeeze(0).float()
    idx = torch.arange(num_nodes)
    adj[idx, idx] = features
    return adj


if __name__ == '__main__':
    dataset = GraphDataset("./")
    num_nodes = dataset[0].x.size(0)
    model = MLP(num_nodes)
    model.forward = undirected_weighted_sn_invariant_decorator(model.forward)

    # Isomorphism test: a simultaneous row/column permutation relabels the nodes of
    # graph 0, so an invariant model should print (numerically) zero.
    adj_matrix = _dense_adj_with_features(dataset[0])
    y1 = model(adj_matrix)
    S = generate_Sn_matrix(num_nodes)
    y2 = model(S.T @ adj_matrix @ S)
    print(y2 - y1)

    # Non-isomorphism test: evaluate graph 1 with its *own* edges and features.
    # Mixing graph 0's edges with graph 1's features would not test graph 1 at all.
    other = dataset[1]
    if other.x.size(0) != num_nodes:
        raise ValueError(
            f"graph 1 has {other.x.size(0)} nodes but the model was built for {num_nodes}"
        )
    adj_matrix = _dense_adj_with_features(other)
    y2 = model(adj_matrix)
    print(y2 - y1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment