This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
import torch.linalg as linalg | |
def calculate_2_wasserstein_dist(X, Y): | |
''' | |
Calulates the two components of the 2-Wasserstein metric: | |
The general formula is given by: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2] | |
For multivariate gaussian distributed inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Union, List, Optional | |
from collections.abc import Mapping, Sequence | |
import torch.utils.data | |
from torch.utils.data.dataloader import default_collate | |
from torch_geometric.data import Data, HeteroData, Dataset, Batch | |
from torch_geometric_temporal.signal import StaticGraphTemporalSignal as SGTS | |
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal as DGTS | |
from torch_geometric_temporal.signal import DynamicGraphStaticSignal as DGSS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch_geometric.data import Data as GraphData | |
# ... load training data | |
train_data = None | |
# uses the following DataLoader: https://gist.github.com/Flunzmas/5a5c8c8fd553609359704be3174db793 | |
data_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, drop_last=True) | |
for batch_idx, data in enumerate(data_loader): | |
for t, batch_at_timestep in enumerate(data): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
""" | |
Differentiable dual quaternion distance metric in PyTorch. | |
Acknowledgements: | |
- Function q_mul(): https://github.com/facebookresearch/QuaterNet/blob/main/common/quaternion.py | |
- Other functions related to quaternions: re-implementations based on pip package "pyquaternion" | |
- Functions related to dual quaternions: re-implementations based on pip package "dual_quaternions" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def to_line_digraph(self, data: Data) -> Data: | |
""" | |
TODO can we make this more efficient by removing the for-loop? | |
""" | |
assert data.edge_index is not None | |
assert data.is_directed() | |
edge_index, edge_attr = data.edge_index, data.edge_attr | |
N, E = data.num_nodes, data.num_edges | |
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes=data.num_nodes) |