Skip to content

Instantly share code, notes, and snippets.

@Forbu
Created February 6, 2023 13:27
Show Gist options
  • Save Forbu/e22e618524661f5ae86b525244c55c8f to your computer and use it in GitHub Desktop.
Save Forbu/e22e618524661f5ae86b525244c55c8f to your computer and use it in GitHub Desktop.
from torch_geometric.nn.conv import MessagePassing
class MPGNNConv(MessagePassing):
    """One message-passing layer with MLP-based edge messages and node updates.

    For every edge ``j -> i`` a message is computed by ``lin_edge`` from the
    concatenation ``[x_i, x_j, edge_attr]``. The messages arriving at each node
    are aggregated (``aggr``, mean by default), concatenated with that node's
    own features, and passed through ``lin_node`` to give the new node
    features. The input edge features are returned unchanged.

    Args:
        node_dim: Width of the node feature vectors. This is the *feature*
            size; it is unrelated to PyG's ``MessagePassing(node_dim=...)``
            propagation axis, which is fixed to 0 below.
        edge_dim: Width of the edge feature vectors.
        layers: Forwarded to ``MLP`` as ``hidden_layers`` for both MLPs.
        aggr: Aggregation scheme forwarded to ``MessagePassing`` (e.g.
            ``"mean"``, ``"sum"``, ``"max"``). Defaults to ``"mean"``, the
            value that was previously hard-coded.
    """

    def __init__(self, node_dim, edge_dim, layers=3, *, aggr="mean"):
        # PyG's `node_dim=0` means nodes are indexed along axis 0 of `x`;
        # do not confuse it with this constructor's `node_dim` (feature width).
        super().__init__(aggr=aggr, node_dim=0)
        # Edge MLP: [x_i, x_j, edge_attr] -> message of width node_dim.
        self.lin_edge = MLP(
            in_dim=node_dim * 2 + edge_dim,
            out_dim=node_dim,
            hidden_layers=layers,
        )
        # Node MLP: [x, aggregated message] -> updated node features.
        self.lin_node = MLP(
            in_dim=node_dim * 2,
            out_dim=node_dim,
            hidden_layers=layers,
        )

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor):
        """Run one round of message passing followed by the node update.

        Args:
            x: Node features, nodes along axis 0; presumably
                ``[num_nodes, node_dim]`` -- confirm against callers.
            edge_index: Graph connectivity passed straight to ``propagate``.
            edge_attr: Edge features, one entry per edge in ``edge_index``.

        Returns:
            A pair ``(x_new, edge_attr)``. ``x_new`` is the output of
            ``lin_node``. ``edge_attr`` is the *unmodified* input: the edge
            MLP output is used only as the message and is not written back.
        """
        # `propagate` calls `message` per edge and aggregates the results per
        # target node using the configured `aggr`.
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # Update each node from both its own state and its aggregated messages.
        x = torch.cat((x, aggregated), dim=-1)
        x = self.lin_node(x)
        return x, edge_attr

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor):
        """Compute the message sent along each edge ``j -> i``.

        PyG gathers ``x_i`` and ``x_j`` from ``x`` using ``edge_index``. With
        the default ``flow="source_to_target"``, they are the target and
        source node features of each edge, respectively.
        """
        x = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.lin_edge(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment