Skip to content

Instantly share code, notes, and snippets.

@emptymalei
Created January 16, 2022 08:48
Show Gist options
  • Save emptymalei/0269c60796262172256ce588c93734c0 to your computer and use it in GitHub Desktop.
Save emptymalei/0269c60796262172256ce588c93734c0 to your computer and use it in GitHub Desktop.
stemgnn_experiment
import torch
from torch.utils.data import Dataset
class FakeTimeSeriesDataset(Dataset):
    """Sliding-window dataset over synthetic, phase-shifted sine waves.

    ``nodes`` series are generated: series ``i`` is ``sin`` evaluated at
    ``sequence_length + 1`` evenly spaced points on
    ``[i, sequence_length + i]``, so ``self.data`` has shape
    ``(nodes, sequence_length + 1)``.

    Each item is an ``(input, target)`` pair cut at the same position from
    every series: ``input`` has shape ``(nodes, input_length)`` and
    ``target`` holds the next ``prediction_length`` points, shape
    ``(nodes, prediction_length)``.

    Args:
        sequence_length: Number of steps in each generated series (each
            series holds ``sequence_length + 1`` samples).
        input_length: Length of the input window.
        prediction_length: Number of points after the window to predict.
        nodes: Number of parallel series.

    Raises:
        ValueError: If ``sequence_length <= prediction_length``.
    """

    def __init__(self, sequence_length, input_length, prediction_length, nodes) -> None:
        super().__init__()
        self.sequence_length = sequence_length
        self.prediction_length = prediction_length
        self.input_length = input_length
        # Raise instead of assert so the check is not stripped under `python -O`.
        if not self.sequence_length > self.prediction_length:
            raise ValueError(
                f"sequence_length ({sequence_length}) must be greater than "
                f"prediction_length ({prediction_length})"
            )
        self.nodes = nodes
        self._gen_data()

    def _gen_data(self):
        """Fill ``self.data`` with one phase-shifted sine wave per node."""
        series = [
            torch.sin(
                torch.linspace(i, self.sequence_length + i, self.sequence_length + 1)
            )
            for i in range(self.nodes)
        ]
        self.data = torch.stack(series)

    def _slice(self, index):
        """Return the ``(input, target)`` windows starting at column ``index``."""
        input_end = index + self.input_length
        return (
            self.data[:, index:input_end],
            self.data[:, input_end : input_end + self.prediction_length],
        )

    def __len__(self):
        """Number of windows whose target lies entirely inside ``self.data``."""
        # Windows that fit completely in the (sequence_length + 1)-column data.
        full_windows = (
            self.data.shape[1] - self.input_length - self.prediction_length + 1
        )
        # Keep the original count, but never expose windows whose target would
        # run past the end (which happened for prediction_length > 2 and
        # produced silently shortened targets). Clamp at 0 so len() never fails.
        return max(0, min(self.sequence_length - self.input_length, full_windows))

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair for ``index``.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError``. Tensor slicing never raises on its own, and without
        this check plain iteration over the dataset would never terminate.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(
                f"index out of range for dataset of length {length}"
            )
        return self._slice(index)
if __name__ == "__main__":
    # Smoke check: 3 series of 101 samples, 5-step inputs with 1-step targets.
    demo = FakeTimeSeriesDataset(100, 5, 1, 3)
    print(demo.data.shape)
    first_sample = demo[0]
    window, target = first_sample
    print(window.shape, target.shape)
    print(first_sample)
from stemgnn import Model
from data import FakeTimeSeriesDataset
import pytorch_lightning as pl
from torch import nn
import torch
from loguru import logger
from pytorch_lightning.loggers import TensorBoardLogger
class SGNN(pl.LightningModule):
    """PyTorch Lightning wrapper that trains a StemGNN ``Model`` with MSE loss.

    Args:
        nodes: Number of series; first positional argument to ``Model``.
        stemgnn_stacks: Second positional argument to ``Model``.
        window_size: Input window length; third argument to ``Model`` and the
            last dimension of ``example_input_array``.
        multi_layer: Fourth positional argument to ``Model``.
        horizon: Forecast length; fifth positional argument to ``Model``.
        learning_rate: Adam learning rate. Defaults to ``1e-3``, the value
            that was previously hard-coded.
    """

    def __init__(
        self,
        nodes,
        stemgnn_stacks=2,
        window_size=12,
        multi_layer=5,
        horizon=3,
        learning_rate=1e-3,
    ):
        super().__init__()
        self.nodes = nodes
        self.stemgnn_stacks = stemgnn_stacks
        self.window_size = window_size
        self.multi_layer = multi_layer
        self.horizon = horizon
        self.learning_rate = learning_rate
        self.model = Model(
            self.nodes,
            self.stemgnn_stacks,
            self.window_size,
            self.multi_layer,
            self.horizon,
        )
        self.loss_function = nn.MSELoss(reduction="mean")
        # Sample batch that Lightning reads for model summaries and for graph
        # logging (TensorBoardLogger(log_graph=True)). Layout is
        # (batch=2, nodes, window_size), mirroring the (nodes, input_length)
        # samples of FakeTimeSeriesDataset.
        # NOTE(review): confirm this is the input layout stemgnn.Model expects.
        self.example_input_array = torch.rand((2, self.nodes, self.window_size))

    def forward(self, x):
        """Return the forecast produced by ``Model``, dropping its second output."""
        forecast, _ = self.model(x)
        return forecast

    def training_step(self, batch, batch_idx):
        """Compute, log and return the MSE between forecast and target."""
        x, y = batch
        # Route through forward() so training and inference share one code path.
        forecast = self(x)
        loss = self.loss_function(forecast, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return Adam over all parameters using ``self.learning_rate``."""
        return torch.optim.Adam(
            params=self.parameters(), lr=self.learning_rate, betas=(0.9, 0.999)
        )
if __name__ == "__main__":
    # Small synthetic problem: 3 series, 5-step input windows, 1-step targets.
    dataset = FakeTimeSeriesDataset(
        sequence_length=100, input_length=5, prediction_length=1, nodes=3
    )
    logger.debug(f"Loaded dataset: {dataset[0]}")
    dl = torch.utils.data.DataLoader(dataset, batch_size=2)
    m = SGNN(
        nodes=dataset.nodes,
        window_size=dataset.input_length,
        horizon=dataset.prediction_length,
    )
    # Bind the TensorBoard logger to its own name so the loguru `logger` used
    # above is not shadowed. Pass it to the Trainer: previously it was created
    # but never handed over, so TensorBoard logging (and log_graph=True) did
    # nothing.
    tb_logger = TensorBoardLogger("tb_logs", name="StemGNN", log_graph=True)
    trainer = pl.Trainer(logger=tb_logger)
    trainer.fit(m, dl)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment