Skip to content

Instantly share code, notes, and snippets.

@lgray
Created October 9, 2020 13:40
Show Gist options
  • Select an option

  • Save lgray/bae7bb41b2227c096a082e593ab4516b to your computer and use it in GitHub Desktop.

Select an option

Save lgray/bae7bb41b2227c096a082e593ab4516b to your computer and use it in GitHub Desktop.
class GraphWeightsNetwork(nn.Module):
    """Graph network producing one output vector (or scalar) per node.

    Each node carries a continuous feature vector and three integer
    categorical features. The categorical features are embedded and
    projected, the continuous features are projected, and the two halves are
    concatenated into a ``hidden_dim``-wide node embedding. That embedding is
    refined by ``conv_depth`` residual ``EdgeConv`` layers and mapped to
    ``output_dim`` values per node by a small MLP.
    """

    def __init__(self, continuous_dim: int, cat_dim: int, output_dim: int = 1,
                 hidden_dim: int = 32, conv_depth: int = 1):
        """Build the embedding, convolution and output layers.

        Args:
            continuous_dim: Number of continuous input features per node.
            cat_dim: Accepted for interface compatibility; not used here
                (the three categorical columns are hard-wired in ``forward``).
            output_dim: Number of outputs per node.
            hidden_dim: Width of the node embedding fed to the convolutions.
            conv_depth: Number of residual ``EdgeConv`` layers.
        """
        # Zero-argument super(): the original named a different class here
        # (GraphMETNetwork), which made construction fail.
        super().__init__()

        embed_dim = hidden_dim // 4
        # Width of the projected categorical half; the continuous half takes
        # whatever remains so the concatenation is exactly hidden_dim wide,
        # even when hidden_dim is odd.
        cat_out_dim = hidden_dim // 2
        cont_out_dim = hidden_dim - cat_out_dim

        # Vocabulary sizes: 3 charge values, 7 pdgId classes, 8 PV classes.
        self.embed_charge = nn.Embedding(3, embed_dim)
        self.embed_pdgid = nn.Embedding(7, embed_dim)
        self.embed_pv = nn.Embedding(8, embed_dim)

        # NOTE: a BatchNorm1d could be appended to the blocks below if the
        # model starts overtraining (left disabled by the original author).
        self.embed_continuous = nn.Sequential(
            nn.Linear(continuous_dim, cont_out_dim),
            nn.ELU(),
        )
        # Input width is the actual width of the three concatenated
        # embeddings, 3 * (hidden_dim // 4). The original `3*hidden_dim//4`
        # only matches this when hidden_dim is a multiple of 4.
        self.embed_categorical = nn.Sequential(
            nn.Linear(3 * embed_dim, cat_out_dim),
            nn.ELU(),
        )

        self.conv_continuous = nn.ModuleList()
        for _ in range(conv_depth):
            # Message MLP takes a 2 * hidden_dim wide input and returns
            # hidden_dim features.
            mesg = nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim),
                nn.ELU(),
            )
            # Alternatives considered by the original author:
            # GATConv / GCNConv / SGConv (hidden_dim -> hidden_dim).
            self.conv_continuous.append(EdgeConv(nn=mesg).jittable())

        self.output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x_cont, x_cat, edge_index, batch):
        """Compute the per-node output.

        Args:
            x_cont: Continuous node features, indexed as
                ``(num_nodes, continuous_dim)``.
            x_cat: Integer categorical node features. Column 0 indexes the
                pdgId embedding, column 1 is shifted by +1 before indexing
                the charge embedding, and column 2 indexes the PV embedding.
            edge_index: Graph connectivity passed to every convolution.
            batch: Accepted for interface compatibility; not used here.

        Returns:
            The output head applied to every node, with a trailing dimension
            of size 1 squeezed away (so ``output_dim == 1`` yields one value
            per node).
        """
        emb_cont = self.embed_continuous(x_cont)

        # The +1 shift maps column 1 onto the 3 rows of the charge table.
        emb_chrg = self.embed_charge(x_cat[:, 1] + 1)
        emb_pdg = self.embed_pdgid(x_cat[:, 0])
        emb_pv = self.embed_pv(x_cat[:, 2])
        emb_cat = self.embed_categorical(
            torch.cat([emb_chrg, emb_pdg, emb_pv], dim=1)
        )

        emb = torch.cat([emb_cat, emb_cont], dim=1)

        # Residual connection around every graph convolution.
        for co_conv in self.conv_continuous:
            emb = emb + co_conv(emb, edge_index)

        out = self.output(emb)
        return out.squeeze(-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment