Skip to content

Instantly share code, notes, and snippets.

@lgray
Created July 24, 2020 17:49
Show Gist options
  • Select an option

  • Save lgray/d1d624e207050e34e2f9b38ce966b87c to your computer and use it in GitHub Desktop.

Select an option

Save lgray/d1d624e207050e34e2f9b38ce966b87c to your computer and use it in GitHub Desktop.
import torch
import torch_geometric
from torch import nn
from torch_geometric.nn.conv import EdgeConv
class EdgeNetWithCategoriesJittable(nn.Module):
    """Edge classifier built from stacked EdgeConv layers, written to be TorchScript-scriptable.

    Pipeline (see ``forward``):

    1. Node features are scaled elementwise by the trainable ``datanorm``.
    2. The scaled features are embedded by ``inputnet``.
    3. The embedding passes through ``n_iters`` EdgeConv layers. Every layer
       receives the previous node embedding concatenated with the scaled
       input features.
    4. The outputs of all EdgeConv layers are concatenated per node.
    5. For each edge, the two endpoint vectors are concatenated, and
       ``edgenetwork`` maps them to ``output_dim`` log-probabilities (its last
       layer is ``LogSoftmax``).

    Args:
        input_dim: Number of node input features (in_features of ``inputnet``).
        hidden_dim: Width of the node embedding produced by ``inputnet`` and by
            each EdgeConv layer.
        output_dim: Number of edge categories scored by ``edgenetwork``.
        n_iters: Number of EdgeConv layers; must be at least 1.
        aggr: Neighbourhood aggregation passed to every ``EdgeConv``.
        norm: Initial per-feature scale factors (tensor or float sequence).
            A private copy is stored as the trainable parameter ``datanorm``.

    Raises:
        ValueError: If ``n_iters`` is less than 1.

    NOTE(review): the default ``norm`` has 5 entries while the default
    ``input_dim`` is 3. ``forward`` multiplies ``x`` by ``datanorm`` (result's
    last dim is 5) and then feeds that to ``nn.Linear(input_dim, ...)``. With
    both defaults those two steps cannot both succeed. Callers presumably pass
    ``input_dim=5`` or a matching ``norm``; confirm before relying on the
    defaults.
    """

    def __init__(self, input_dim=3, hidden_dim=8, output_dim=4, n_iters=1, aggr='add',
                 norm=torch.tensor([1./500., 1./500., 1./54., 1/25., 1./1000.])):
        super().__init__()
        if n_iters < 1:
            # firstnodenetwork always runs, so edgenetwork would receive at least
            # 2*hidden_dim features but be built for 2*n_iters*hidden_dim.
            raise ValueError(f"n_iters must be >= 1, got {n_iters}")

        # nn.Parameter shares storage with the tensor it wraps. The default
        # ``norm`` is created once, when the class is defined. Without this copy,
        # every instance (and the default itself) would share one storage, so
        # training one model would silently rescale the inputs of all the others.
        # NOTE(review): datanorm is trainable; if it is meant to be a fixed
        # normalisation, a buffer may be more appropriate -- confirm intent.
        self.datanorm = nn.Parameter(torch.as_tensor(norm).detach().clone())

        # EdgeConv node features here are cat([H, x_norm]), of width
        # hidden_dim + input_dim. EdgeConv's MLP sees two such vectors per edge
        # (see the PyG EdgeConv docs), hence the factor 2.
        start_width = 2 * (hidden_dim + input_dim)
        # Floor of the midpoint between start_width and hidden_dim.
        middle_width = (3 * hidden_dim + 2 * input_dim) // 2
        self.n_iters = n_iters

        self.inputnet = nn.Sequential(
            nn.Linear(input_dim, 2 * hidden_dim),
            nn.Tanh(),
            nn.Linear(2 * hidden_dim, 2 * hidden_dim),
            nn.Tanh(),
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        # Input: both endpoints' concatenation of all n_iters EdgeConv outputs.
        self.edgenetwork = nn.Sequential(
            nn.Linear(2 * n_iters * hidden_dim, 2 * hidden_dim),
            nn.ELU(),
            nn.Linear(2 * hidden_dim, 2 * hidden_dim),
            nn.ELU(),
            nn.Linear(2 * hidden_dim, output_dim),
            nn.LogSoftmax(dim=-1),
        )

        def make_edge_conv():
            # Builds one scriptable EdgeConv layer with a fresh MLP. Layers are
            # created in the same order as before, so parameter initialisation
            # under a fixed seed is unchanged.
            convnn = nn.Sequential(
                nn.Linear(start_width, middle_width),
                nn.ELU(),
                nn.Linear(middle_width, hidden_dim),
                nn.ELU(),
            )
            return EdgeConv(nn=convnn, aggr=aggr).jittable()

        self.firstnodenetwork = make_edge_conv()
        self.nodenetwork = nn.ModuleList()
        for _ in range(n_iters - 1):
            self.nodenetwork.append(make_edge_conv())

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Score every edge in ``edge_index``.

        Args:
            x: Node features. The last dimension is multiplied elementwise by
                ``datanorm``.
            edge_index: Edge endpoints. ``edge_index[0]`` and ``edge_index[1]``
                give the two endpoints of each edge (presumably PyG's (2, E)
                source/target layout; confirm).

        Returns:
            ``edgenetwork`` output per edge (log-probabilities over
            ``output_dim`` categories). A trailing size-1 dimension is squeezed,
            which only has an effect when ``output_dim == 1``.
        """
        row = edge_index[0]
        col = edge_index[1]
        x_norm = self.datanorm * x
        H = self.inputnet(x_norm)
        H = self.firstnodenetwork(torch.cat([H, x_norm], dim=-1), edge_index)
        H_cat = H
        for nodenetwork in self.nodenetwork:
            H = nodenetwork(torch.cat([H, x_norm], dim=-1), edge_index)
            # The newest layer output is prepended; edgenetwork's weights depend
            # on this ordering.
            H_cat = torch.cat([H, H_cat], dim=-1)
        return self.edgenetwork(torch.cat([H_cat[row], H_cat[col]], dim=-1)).squeeze(-1)
# Smoke test: build a six-layer model and compile it with TorchScript.
# torch.jit.script raises if forward() or any submodule it reaches is not
# scriptable. The eager module is kept in ``test`` and the scripted one in ``out``.
test = EdgeNetWithCategoriesJittable(
    n_iters=6,
)
out = torch.jit.script(
    test
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment