Skip to content

Instantly share code, notes, and snippets.

View Flunzmas's full-sized avatar
🤗

Andreas Boltres Flunzmas

🤗
View GitHub Profile
@Flunzmas
Flunzmas / to_line_digraph.py
Created June 25, 2024 06:59
This function converts a PyTorch Geometric `Data` object representing a directed graph into its line digraph representation.
def to_line_digraph(self, data: Data) -> Data:
"""
TODO can we make this more efficient by removing the for-loop?
"""
assert data.edge_index is not None
assert data.is_directed()
edge_index, edge_attr = data.edge_index, data.edge_attr
N, E = data.num_nodes, data.num_edges
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes=data.num_nodes)
@Flunzmas
Flunzmas / dual_quaternion_distance.py
Created October 11, 2021 08:57
Differentiable dual quaternion distance metric in PyTorch
import math
import torch
"""
Differentiable dual quaternion distance metric in PyTorch.
Acknowledgements:
- Function q_mul(): https://github.com/facebookresearch/QuaterNet/blob/main/common/quaternion.py
- Other functions related to quaternions: re-implementations based on pip package "pyquaternion"
- Functions related to dual quaternions: re-implementations based on pip package "dual_quaternions"
@Flunzmas
Flunzmas / direct_access.py
Last active September 27, 2021 12:49
PyG: Access individual graphs from a Batch object not created through Batch.from_data_list()
from torch_geometric.data import Data as GraphData
# ... load training data
train_data = None
# uses the following DataLoader: https://gist.github.com/Flunzmas/5a5c8c8fd553609359704be3174db793
data_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, drop_last=True)
for batch_idx, data in enumerate(data_loader):
for t, batch_at_timestep in enumerate(data):
@Flunzmas
Flunzmas / pygt_loader.py
Created September 27, 2021 12:25
DataLoader for pytorch-geometric-temporal (direct extension of the loader from pytorch-geometric)
from typing import Union, List, Optional
from collections.abc import Mapping, Sequence
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Data, HeteroData, Dataset, Batch
from torch_geometric_temporal.signal import StaticGraphTemporalSignal as SGTS
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal as DGTS
from torch_geometric_temporal.signal import DynamicGraphStaticSignal as DGSS
@Flunzmas
Flunzmas / calc_2_wasserstein_dist.py
Last active October 17, 2024 01:31
Differentiable 2-Wasserstein Distance in PyTorch
import math
import torch
import torch.linalg as linalg
def calculate_2_wasserstein_dist(X, Y):
'''
Calculates the two components of the 2-Wasserstein metric:
The general formula is given by: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2]
For multivariate gaussian distributed inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y),