This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the GCN link-prediction model, train it, and report test-set ROC-AUC.
# `Net` takes (in_channels, hidden_channels, out_channels); here the input
# width comes from the dataset's node features, with 128 hidden and 64 output
# channels.
# NOTE(review): `torch`, `device`, `train_link_predictor` and
# `eval_link_predictor` are not defined in the visible snippets — confirm
# they are in scope before running this script.
model = Net(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
# BCEWithLogitsLoss applies the sigmoid internally, so the model is presumably
# expected to emit raw (un-sigmoided) edge scores — TODO confirm in Net.
criterion = torch.nn.BCEWithLogitsLoss()
model = train_link_predictor(model, train_data, val_data, optimizer, criterion)
test_auc = eval_link_predictor(model, test_data)
print(f"Test: {test_auc:.3f}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch_geometric.transforms as T

# Randomly assign nodes to train/validation/test splits: 10% of nodes go to
# validation and 20% to test; the remainder is used for training.
split = T.RandomNodeSplit(num_val=0.1, num_test=0.2)
graph = split(graph)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from torch_geometric.utils import to_networkx | |
import networkx as nx | |
def convert_to_networkx(graph, n_sample=None): | |
g = to_networkx(graph, node_attrs=["x"]) | |
y = graph.y.numpy() | |
if n_sample is not None: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
0: inlier | |
1: contextual outlier only | |
2: structural outlier only | |
3: both contextual outlier and structural outlier | |
""" | |
Counter(graph.y.tolist()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch_geometric.transforms as T

# Split the graph's edges into train/validation/test sets for link prediction.
split = T.RandomLinkSplit(
    num_val=0.05,  # 5% of edges held out for validation
    num_test=0.1,  # 10% of edges held out for testing
    is_undirected=True,  # treat each edge pair as a single undirected edge
    # No negative (non-existent) edges are added to the training split, so
    # they must be sampled during training instead.
    add_negative_train_samples=False,
    neg_sampling_ratio=1.0,  # one negative edge per positive edge in val/test
)
train_data, val_data, test_data = split(graph)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pygod.models import DOMINANT | |
from sklearn.metrics import roc_auc_score, average_precision_score | |
def train_anomaly_detector(model, graph):
    """Fit ``model`` on ``graph`` and return whatever ``model.fit`` returns.

    Args:
        model: An object exposing a ``fit(graph)`` method (e.g. a PyGOD
            detector such as ``DOMINANT``).
        graph: The graph data passed straight through to ``model.fit``.

    Returns:
        The return value of ``model.fit(graph)``.
    """
    return model.fit(graph)
def eval_anomaly_detector(model, graph): | |
outlier_scores = model.decision_function(graph) | |
auc = roc_auc_score(graph.y.numpy(), outlier_scores) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pygod.utils import load_data

# Load PyGOD's 'inj_cora' benchmark graph (presumably Cora with injected
# outliers — confirm against the PyGOD docs).
graph = load_data('inj_cora')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.metrics import roc_auc_score | |
from torch_geometric.utils import negative_sampling | |
class Net(torch.nn.Module): | |
def __init__(self, in_channels, hidden_channels, out_channels): | |
super().__init__() | |
self.conv1 = GCNConv(in_channels, hidden_channels) | |
self.conv2 = GCNConv(hidden_channels, out_channels) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch_geometric.nn import GCNConv | |
import torch.nn.functional as F | |
class GCN(torch.nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.conv1 = GCNConv(dataset.num_node_features, 16) | |
self.conv2 = GCNConv(16, dataset.num_classes) | |
def forward(self, data): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch_geometric.datasets import Planetoid

# Load the Cora dataset from the Planetoid collection, cached under /tmp/Cora.
dataset = Planetoid(root='/tmp/Cora', name='Cora')
# Take the dataset's first graph object (index 0).
graph = dataset[0]
NewerOlder