@lgray
Created August 7, 2020 19:03
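
import torch
from torch_geometric.nn import knn_graph
from torch_geometric.typing import OptTensor


def forward(self, x, batch: OptTensor = None):
    # Normalize the raw inputs and project them into the hidden feature space.
    x = self.datanorm * x
    x = self.inputnet(x)
    # DGCNN: rebuild the kNN graph in the current feature space before each EdgeConv.
    # Assumes self.edgeconvs is non-empty; otherwise edge_index is never defined.
    for ec in self.edgeconvs:
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=ec.flow)
        x = ec(x, edge_index)
    out = self.output(x)  # output the embedding directly from the DGCNN
    # Reuse the kNN graph built for the last EdgeConv and do edge classification on it.
    for ecc in self.edgecatconvs:
        x = ecc(x, edge_index)
    # Score each edge from the concatenated features of its two endpoints.
    # squeeze(-1) keeps a 1-D result even when the graph has a single edge.
    edge_cat = self.edge_classifier(
        torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=1)
    ).squeeze(-1)
    return out, edge_cat, edge_index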
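The two graph operations in `forward()` (building a kNN graph without self-loops, restricted to each batch element, and then forming per-edge features by concatenating both endpoints) can be illustrated without PyTorch. The NumPy sketch below is a brute-force stand-in, not the PyTorch Geometric implementation: it assumes Euclidean distance and assumes the same `edge_index` convention as PyG's `knn_graph(..., loop=False)`, where with `flow="source_to_target"` row 0 holds the neighbour (source) and row 1 the centre node (target). `knn_graph_np` and `edge_pair_features` are hypothetical helper names.

```python
import numpy as np


def knn_graph_np(x, k, batch=None, flow="source_to_target"):
    """Brute-force kNN graph without self-loops (hypothetical helper).

    Each node is linked to its k nearest neighbours (Euclidean distance),
    searched only among nodes with the same batch id. A node whose batch
    element has fewer than k other nodes gets fewer than k edges.
    Returns a (2, E) integer array.
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    batch = np.zeros(n, dtype=int) if batch is None else np.asarray(batch)

    # Pairwise squared Euclidean distances between all nodes.
    diff = x[:, None, :] - x[None, :, :]
    dist = (diff ** 2).sum(axis=-1)
    # Exclude self-loops and pairs from different batch elements.
    dist[np.arange(n), np.arange(n)] = np.inf
    dist[batch[:, None] != batch[None, :]] = np.inf

    centers, neighbors = [], []
    for i in range(n):
        order = np.argsort(dist[i], kind="stable")
        for j in order[:k]:
            if np.isfinite(dist[i, j]):
                centers.append(i)
                neighbors.append(j)
    centers = np.array(centers, dtype=int)
    neighbors = np.array(neighbors, dtype=int)

    if flow == "source_to_target":
        # Messages flow from neighbour j (row 0) into centre i (row 1).
        return np.stack([neighbors, centers])
    return np.stack([centers, neighbors])


def edge_pair_features(x, edge_index):
    """Concatenate the features of both endpoints of every edge: (E, 2F)."""
    x = np.asarray(x)
    return np.concatenate([x[edge_index[0]], x[edge_index[1]]], axis=1)


if __name__ == "__main__":
    # Two events in one batch: nodes 0-2 belong to event 0, nodes 3-4 to event 1.
    x = np.array([[0.0], [1.0], [3.0], [10.0], [11.0]])
    batch = np.array([0, 0, 0, 1, 1])
    ei = knn_graph_np(x, k=1, batch=batch)
    print(ei.tolist())  # [[1, 0, 1, 4, 3], [0, 1, 2, 3, 4]]
    print(edge_pair_features(x, ei).shape)  # (5, 2)
```

Because the batch mask sets cross-event distances to infinity, nodes 3 and 4 never connect to event 0, which is the role the `batch` argument plays in the original `knn_graph` call. In the original model, each row of the concatenated edge features is what `self.edge_classifier` would turn into one score per edge.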