Created February 6, 2023 13:27
Save Forbu/e22e618524661f5ae86b525244c55c8f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
class MPGNNConv(MessagePassing):
    """One message-passing layer with MLP-based edge messages and node update.

    For every edge ``j -> i`` a message is produced by ``lin_edge`` from the
    concatenation ``[x_i, x_j, edge_attr]``. The messages arriving at each node
    are aggregated (``aggr``, mean by default). Then ``lin_node`` maps
    ``[x, aggregated_message]`` to the new node features.

    Args:
        node_dim: Size of the node feature vectors, both input and output.
            This is a *feature size*. Do not confuse it with the ``node_dim``
            keyword of ``MessagePassing``, which is an axis index.
        edge_dim: Size of the edge feature vectors.
        layers: Forwarded as ``hidden_layers`` to both MLPs.
        aggr: Aggregation scheme forwarded to ``MessagePassing``
            (for example ``"mean"``, ``"add"`` or ``"max"``). It defaults to
            ``"mean"``, the previously hard-coded behaviour.
    """

    def __init__(self, node_dim, edge_dim, layers=3, aggr="mean"):
        # node_dim=0 is MessagePassing's propagation axis: nodes are indexed
        # along dimension 0 of the feature tensors.
        super().__init__(aggr=aggr, node_dim=0)
        # NOTE(review): ``MLP`` is not defined or imported in this file. Its
        # (in_dim, out_dim, hidden_layers) signature does not match
        # torch_geometric.nn.MLP, so it is presumably a project-local class.
        # Confirm where it comes from.
        # Edge MLP: [x_i, x_j, edge_attr] -> message of size node_dim.
        self.lin_edge = MLP(
            in_dim=node_dim * 2 + edge_dim,
            out_dim=node_dim,
            hidden_layers=layers,
        )
        # Node MLP: [x, aggregated message] -> updated node features.
        self.lin_node = MLP(
            in_dim=node_dim * 2,
            out_dim=node_dim,
            hidden_layers=layers,
        )

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor):
        """Run message passing, then apply the node MLP.

        Args:
            x: Node features. Presumably shaped (num_nodes, node_dim); nodes
                lie on axis 0 (see ``node_dim=0`` above).
            edge_index: Graph connectivity passed to ``propagate``.
                Presumably the usual PyG (2, num_edges) index tensor.
            edge_attr: Edge features. Presumably shaped
                (num_edges, edge_dim).

        Returns:
            A tuple ``(x_new, edge_attr)``. ``x_new`` is the output of
            ``lin_node``. ``edge_attr`` is the input edge features, returned
            unchanged.
        """
        # Aggregated per-node messages (one ``lin_edge`` output per edge,
        # combined with ``aggr``).
        message_info = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # Feature-wise concat: node_dim + node_dim matches lin_node's in_dim.
        x = torch.cat((x, message_info), dim=-1)
        x = self.lin_node(x)
        # NOTE(review): the edge MLP output is used only as the message, so
        # edge features are not updated by this layer. Confirm this is
        # intended if stacked layers are expected to refine edge features.
        return x, edge_attr

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor) -> Tensor:
        """Compute one message per edge from target, source and edge features.

        ``x_i`` and ``x_j`` are the endpoint features that ``propagate``
        gathers for each edge. They are concatenated with ``edge_attr`` along
        the last axis, giving ``node_dim * 2 + edge_dim`` features, and fed to
        ``lin_edge``.
        """
        edge_input = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.lin_edge(edge_input)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.