Skip to content

Instantly share code, notes, and snippets.

@lgray
Created October 8, 2020 20:26
Show Gist options
  • Select an option

  • Save lgray/50e35dcb190a40a8f05ac8b29f10ad88 to your computer and use it in GitHub Desktop.

Select an option

Save lgray/50e35dcb190a40a8f05ac8b29f10ad88 to your computer and use it in GitHub Desktop.
graphmet with embeddings
def _graphmet_node_encoder(in_dim, hidden_dim):
    """Return a Linear/ReLU/Linear/ReLU/Linear MLP mapping in_dim -> hidden_dim.

    The layer order (Linear weights at Sequential indices 0, 2, 4) matches the
    original inline definitions, so existing state_dict keys are unchanged.
    """
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        # nn.BatchNorm1d(hidden_dim)  # uncomment if it starts overtraining
    )


def _graphmet_edgeconv_stack(hidden_dim, depth):
    """Return a ModuleList of ``depth`` EdgeConv layers working at width hidden_dim.

    Each layer's message MLP takes 2*hidden_dim features and emits hidden_dim,
    so a layer's output can be added residually to its input in ``forward``.
    """
    layers = nn.ModuleList()
    for _ in range(depth):
        message_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, 3 * hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(3 * hidden_dim // 2, hidden_dim),
            # nn.BatchNorm1d(hidden_dim)
        )
        # .jittable() is the PyG hook for TorchScript export; the original also
        # experimented with GCNConv(hidden_dim, hidden_dim).jittable() here.
        layers.append(EdgeConv(nn=message_mlp).jittable())
    return layers


class GraphMETNetwork(nn.Module):
    """Two-branch EdgeConv network producing one output row per node.

    Continuous and categorical node features are embedded separately. Each
    branch is refined by ``conv_depth`` residual EdgeConv layers over
    ``edge_index``. The two branches are then concatenated and passed through
    an output MLP.

    Args:
        continuous_dim: number of columns in ``x_cont``.
        cat_dim: unused; kept for interface compatibility. The categorical
            layout is fixed to the three columns read in ``forward``.
        output_dim: width of the per-node output (squeezed away when 1).
        hidden_dim: width of every embedding and convolution. Each categorical
            column gets an embedding of size ``hidden_dim // 4``.
        conv_depth: number of EdgeConv layers in each branch.
    """

    def __init__(self, continuous_dim, cat_dim, output_dim=1, hidden_dim=32, conv_depth=1):
        super().__init__()
        cat_embed_dim = hidden_dim // 4
        # Embedding tables accept indices in [0, 3), [0, 8) and [0, 2).
        self.embed_charge = nn.Embedding(3, cat_embed_dim)
        self.embed_pdgid = nn.Embedding(8, cat_embed_dim)
        self.embed_pv = nn.Embedding(2, cat_embed_dim)
        self.embed_continuous = _graphmet_node_encoder(continuous_dim, hidden_dim)
        # The three categorical embeddings are concatenated, so the input width
        # must be 3 * (hidden_dim // 4). Writing 3*hidden_dim//4 would parse as
        # (3*hidden_dim)//4 and mismatch whenever hidden_dim % 4 != 0.
        self.embed_categorical = _graphmet_node_encoder(3 * cat_embed_dim, hidden_dim)
        self.conv_continuous = _graphmet_edgeconv_stack(hidden_dim, conv_depth)
        self.conv_categorical = _graphmet_edgeconv_stack(hidden_dim, conv_depth)
        self.output = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x_cont: torch.Tensor, x_cat: torch.Tensor,
                edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Return per-node outputs: (N, output_dim), or (N,) when output_dim == 1.

        Args:
            x_cont: node features of shape (N, continuous_dim).
            x_cat: integer node features with at least 3 columns. Column 0
                must lie in {-1, 0, 1} (it is shifted by +1 into embed_charge),
                column 1 in [0, 8), and column 2 in {0, 1}.
            edge_index: graph connectivity, passed unchanged to every EdgeConv.
            batch: accepted for interface compatibility; not used here.
        """
        emb_cont = self.embed_continuous(x_cont)
        emb_chrg = self.embed_charge(x_cat[:, 0] + 1)
        emb_pdg = self.embed_pdgid(x_cat[:, 1])
        emb_pv = self.embed_pv(x_cat[:, 2])
        emb_cat = self.embed_categorical(torch.cat([emb_chrg, emb_pdg, emb_pv], dim=1))

        # Residual graph convolutions on each branch (each conv preserves width).
        for co_conv in self.conv_continuous:
            emb_cont = emb_cont + co_conv(emb_cont, edge_index)
        for ca_conv in self.conv_categorical:
            emb_cat = emb_cat + ca_conv(emb_cat, edge_index)

        # Join both node descriptions and map them to the per-node output.
        emb = torch.cat([emb_cont, emb_cat], dim=1)
        out = self.output(emb)
        return out.squeeze(-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment