Created
May 9, 2024 02:41
-
-
Save tiandiao123/c8c28934bee2a76ad7e378f3851f36c7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn.conv import DNAConv
class DynamicGraphNN(torch.nn.Module):
    """Node classifier: one DNAConv layer followed by a linear head.

    Args:
        in_channels: Size of the last dimension of ``data.x`` (features per
            node per history step); also used as the DNAConv channel width.
        out_channels: Number of output classes.
        num_layers: Stored as ``self.num_layers`` but NOT used by ``forward``;
            exactly one DNAConv layer is built regardless of its value.
        heads: Number of attention heads passed to ``DNAConv``.
        groups: Number of groups passed to ``DNAConv``.
        dropout: Attention dropout passed to ``DNAConv`` (default 0.1, the
            value that was previously hard-coded).
    """

    def __init__(self, in_channels, out_channels, num_layers=2, heads=1, groups=1,
                 dropout=0.1):
        super().__init__()
        # NOTE(review): kept for backward compatibility only; forward() ignores it.
        self.num_layers = num_layers
        self.conv_layer = DNAConv(channels=in_channels, heads=heads, groups=groups,
                                  dropout=dropout)
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, data):
        """Return per-node log-probabilities over ``out_channels`` classes.

        ``data.x`` is passed to DNAConv, which (per its docs) expects
        ``[num_nodes, num_history, in_channels]`` and returns one
        ``in_channels``-wide vector per node. ``data.y`` is not read.
        """
        hidden = F.relu(self.conv_layer(data.x, data.edge_index))
        logits = self.lin(hidden)
        # Output is log-probabilities: pair with NLLLoss, not CrossEntropyLoss.
        return F.log_softmax(logits, dim=1)
def simulate_traffic_data(num_nodes, num_features, num_timesteps, generator=None):
    """Build random graph snapshots whose feature history grows over time.

    One random feature tensor of shape
    ``[num_nodes, num_timesteps, num_features]`` is drawn up front. Snapshot
    ``t`` (1-based) exposes its first ``t`` timesteps, so its ``x`` has shape
    ``[num_nodes, t, num_features]``. All snapshots share the same directed
    chain graph ``0 -> 1 -> ... -> num_nodes - 1``. Each snapshot gets freshly
    drawn random binary labels (0 or 1) that are independent of the features.

    Args:
        num_nodes: Number of nodes in the chain.
        num_features: Features per node per timestep.
        num_timesteps: Number of snapshots to produce.
        generator: Optional ``torch.Generator`` for reproducible sampling;
            ``None`` uses PyTorch's global RNG (the previous behavior).

    Returns:
        A list of ``num_timesteps`` ``Data`` objects. Each has ``x``,
        ``edge_index`` of shape ``[2, max(num_nodes - 1, 0)]`` and ``y`` of
        shape ``[num_nodes]``.
    """
    features = torch.randn((num_nodes, num_timesteps, num_features), generator=generator)

    # Directed chain edges i -> i + 1, always shaped [2, E]. The former
    # list-comprehension + .t() form produced a 1-D [0] tensor when there
    # were no edges (num_nodes <= 1), which is not a valid edge_index.
    sources = torch.arange(max(num_nodes - 1, 0), dtype=torch.long)
    edge_index = torch.stack([sources, sources + 1])

    data_list = []
    for t in range(1, num_timesteps + 1):
        # A slice is a view: every snapshot shares storage with `features`.
        history = features[:, :t, :]
        labels = torch.randint(0, 2, (num_nodes,), generator=generator)
        data_list.append(Data(x=history, edge_index=edge_index, y=labels))
    return data_list
# Toy dataset: a 5-node chain with 4 features per node observed over 3
# timesteps, so the snapshots carry 1, 2 and 3 history steps respectively.
data_list = simulate_traffic_data(num_nodes=5, num_features=4, num_timesteps=3)
def train(dataset=None, epochs=10):
    """Train a fresh DynamicGraphNN on a list of graph snapshots.

    Args:
        dataset: List of ``Data`` snapshots whose ``x`` has shape
            ``[num_nodes, t, num_features]`` and whose ``y`` holds integer
            class labels in ``{0, 1}``. Defaults to the module-level
            ``data_list``.
        epochs: Number of passes over ``dataset``.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``dataset`` is empty. Previously this surfaced as a
            ZeroDivisionError when the epoch loss was averaged.
    """
    if dataset is None:
        dataset = data_list
    if not dataset:
        raise ValueError("train() needs at least one Data snapshot")

    # Derive the input width from the data rather than hard-coding 4, so the
    # model always matches the feature size of the snapshots it is fed.
    in_channels = dataset[0].x.size(-1)
    model = DynamicGraphNN(in_channels=in_channels, out_channels=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # The model already returns log-probabilities (log_softmax), so NLLLoss is
    # the matching criterion. CrossEntropyLoss would apply log_softmax again.
    criterion = nn.NLLLoss()

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for data in dataset:
            optimizer.zero_grad()
            out = model(data)
            # No .squeeze(): with a single node it would drop the node axis
            # and misalign predictions [C] with labels [1].
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch}: Loss: {total_loss / len(dataset)}')
    return model
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.