Skip to content

Instantly share code, notes, and snippets.

Created September 27, 2021 12:25
Show Gist options
  • Save Flunzmas/5a5c8c8fd553609359704be3174db793 to your computer and use it in GitHub Desktop.
Save Flunzmas/5a5c8c8fd553609359704be3174db793 to your computer and use it in GitHub Desktop.
DataLoader for pytorch-geometric-temporal (direct extension of the loader from pytorch-geometric)
from typing import Union, List, Optional
from collections.abc import Mapping, Sequence

import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate

from torch_geometric.data import Data, HeteroData, Dataset, Batch
from torch_geometric_temporal.signal import StaticGraphTemporalSignal as SGTS
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal as DGTS
from torch_geometric_temporal.signal import DynamicGraphStaticSignal as DGSS
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch as SGTSBatch
from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch as DGTSBatch
from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch as DGSSBatch
def _batch_field_as_numpy(batch: Batch, key: str):
    """Return ``batch[key]`` as a NumPy array, or ``None`` if the key is absent."""
    # Snapshots built from signals without edge weights or targets have no
    # ``edge_attr`` / ``y`` entry; indexing a missing key would raise, so
    # propagate ``None`` instead.
    return batch[key].numpy() if key in batch else None


def _first_or_none(values: list):
    """Return the first element of ``values``, or ``None`` if it is empty."""
    return values[0] if values else None


def collate_temporal_signal(
    signal_list: List[Union[SGTS, DGTS, DGSS]],
    signal_batch_class: Optional[type] = None,
    follow_batch: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
) -> Union[SGTSBatch, DGTSBatch, DGSSBatch]:
    """Collate a list of temporal signals into one batched temporal signal.

    Each signal is unrolled into its per-timestep graph snapshots. The
    snapshots of all signals at the same timestep are merged with
    ``Batch.from_data_list`` into one diagonalized ``Batch``. The resulting
    per-timestep batches are then reassembled into the ``*Batch`` signal
    class that matches the input signal type.

    Args:
        signal_list: Non-empty list of signals. All must be of the same type
            and have the same number of timesteps.
        signal_batch_class: Ignored; the output class is always inferred from
            the input signal type. Kept (now optional) for backward
            compatibility with existing call sites.
        follow_batch: Forwarded to ``Batch.from_data_list``.
        exclude_keys: Forwarded to ``Batch.from_data_list``.

    Returns:
        The batched signal: ``SGTSBatch`` for ``SGTS`` input, ``DGTSBatch``
        for ``DGTS`` input, ``DGSSBatch`` for ``DGSS`` input.

    Raises:
        ValueError: If ``signal_list`` is empty, mixes signal types, or holds
            signals of differing sequence length.
        TypeError: If the signal type is not SGTS, DGTS or DGSS.
    """
    # Without this guard an empty list fails with an opaque IndexError below.
    if not signal_list:
        raise ValueError("Cannot collate an empty list of signals.")

    # check for inconsistent input signal types
    if len({type(signal) for signal in signal_list}) > 1:
        raise ValueError("Given Signal objects are of differing type!")

    in_type = type(signal_list[0])
    if in_type not in (SGTS, DGTS, DGSS):
        raise TypeError(f"Unsupported temporal signal type: {in_type.__name__}")

    # list(samples) of list(temporal sequence) of torch_geometric.Data/HeteroData
    all_graphs = [list(signal) for signal in signal_list]

    # check for inconsistent sequence lengths
    if len({len(snapshots) for snapshots in all_graphs}) > 1:
        raise ValueError("Batch Samples of differing sequence length are currently not supported.")

    # One diagonalized Batch per timestep, built from that timestep's
    # snapshot of every sample (zip walks the samples in lockstep).
    batches_by_timestep = [
        Batch.from_data_list(list(snapshots_at_t), follow_batch, exclude_keys)
        for snapshots_at_t in zip(*all_graphs)
    ]

    edge_indices = [_batch_field_as_numpy(b, "edge_index") for b in batches_by_timestep]
    edge_weights = [_batch_field_as_numpy(b, "edge_attr") for b in batches_by_timestep]
    features = [_batch_field_as_numpy(b, "x") for b in batches_by_timestep]
    targets = [_batch_field_as_numpy(b, "y") for b in batches_by_timestep]
    batch_vectors = [_batch_field_as_numpy(b, "batch") for b in batches_by_timestep]

    # The three batch classes take different constructor arguments, so each
    # type is assembled separately.
    if in_type is SGTS:
        # Every timestep concatenates the same static graphs in the same
        # order, so the batched edge index, edge weights and batch vector are
        # identical across timesteps; SGTSBatch stores a single copy of each.
        return SGTSBatch(
            edge_index=_first_or_none(edge_indices),
            edge_weight=_first_or_none(edge_weights),
            features=features,
            targets=targets,
            batches=_first_or_none(batch_vectors),
        )
    if in_type is DGSS:
        # The node signal is static, so the batched features repeat at every
        # timestep; DGSSBatch takes a single ``feature`` array.
        return DGSSBatch(
            edge_indices=edge_indices,
            edge_weights=edge_weights,
            feature=_first_or_none(features),
            targets=targets,
            batches=batch_vectors,
        )
    return DGTSBatch(
        edge_indices=edge_indices,
        edge_weights=edge_weights,
        features=features,
        targets=targets,
        batches=batch_vectors,
    )
class PYGTCollater:
    """Collate function for temporal signals, PyG graphs and plain values.

    Dispatches on the type of the first element of ``batch``:
    - temporal signals (SGTS/DGTS/DGSS) go through ``collate_temporal_signal``;
    - ``Data``/``HeteroData`` go through ``Batch.from_data_list``;
    - tensors use ``default_collate``;
    - floats and ints become tensors, and strings are returned as-is;
    - mappings, named tuples and other sequences are collated recursively.
    """

    def __init__(self, follow_batch, exclude_keys):
        # Forwarded to Batch.from_data_list / collate_temporal_signal.
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def collate(self, batch):
        """Collate one batch (a non-empty list of samples) into a single object.

        Raises:
            TypeError: If the element type is not supported.
        """
        elem = batch[0]
        if isinstance(elem, (SGTS, DGTS, DGSS)):
            # Pass the options by keyword. The second positional parameter of
            # collate_temporal_signal is ``signal_batch_class``, so passing
            # them positionally shifted follow_batch into that slot, used
            # exclude_keys as follow_batch, and dropped exclude_keys entirely.
            return collate_temporal_signal(
                batch,
                None,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
        if isinstance(elem, (Data, HeteroData)):
            return Batch.from_data_list(batch, self.follow_batch, self.exclude_keys)
        if isinstance(elem, torch.Tensor):
            return default_collate(batch)
        if isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        if isinstance(elem, int):
            return torch.tensor(batch)
        if isinstance(elem, str):
            return batch
        if isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in batch]) for key in elem}
        if isinstance(elem, tuple) and hasattr(elem, '_fields'):
            # namedtuple: collate field-wise and rebuild the same type
            return type(elem)(*(self.collate(s) for s in zip(*batch)))
        if isinstance(elem, Sequence) and not isinstance(elem, str):
            # collate position-wise across samples
            return [self.collate(s) for s in zip(*batch)]
        raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))

    def __call__(self, batch):
        """Make the collater usable directly as a ``collate_fn``."""
        return self.collate(batch)
class DataLoader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` that batches temporal signals and PyG graphs.

    Any user-supplied ``collate_fn`` is discarded in favour of ``PYGTCollater``.
    """

    def __init__(
        self,
        dataset: Union[Dataset, List[Data], List[HeteroData], List[SGTS], List[DGTS], List[DGSS]],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """Create the loader.

        Args:
            dataset: Dataset or list of graphs / temporal signals.
            batch_size: Number of samples per batch.
            shuffle: Whether to reshuffle the data every epoch.
            follow_batch: Forwarded to ``PYGTCollater``; defaults to ``[]``.
            exclude_keys: Forwarded to ``PYGTCollater``; defaults to ``[]``.
            **kwargs: Remaining ``torch.utils.data.DataLoader`` arguments.
        """
        # ``None`` defaults avoid sharing one mutable default list between all
        # loaders (the lists are stored on ``self`` and could be mutated).
        follow_batch = [] if follow_batch is None else follow_batch
        exclude_keys = [] if exclude_keys is None else exclude_keys

        # This loader always uses its own collater.
        kwargs.pop("collate_fn", None)

        # Save for PyTorch Lightning...
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(dataset, batch_size, shuffle,
                         collate_fn=PYGTCollater(follow_batch, exclude_keys), **kwargs)
Copy link

Hi, I am trying to implement a Hetero Spatial-Temporal GNN which uses HeteroConv with two SAGEConv layers to generate the embedding of nodes in each snapshot, then concatenate the node embeddings from all snapshots, and use a custom 1D ResNet to predict the target value of a specific node type.
I am using StaticHeteroGraphTemporalSignal to convert several StaticHeteroGraphs created using NetworkX (DiGraph) into HeteroData objects.
Can I get an example of how to use StaticHeteroGraphTemporalSignalBatch to create batches of 120 snapshots for a StaticHeteroGraph that contains 8760 snapshots in total?

Additionally, is there a way to train using batches and multiple StaticHeteroGraphs when the edge_index_dict differs from graph to graph?

Copy link

Thanks for sharing!
Could you please share the input and output shapes?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment