@yzhangcs
Created July 7, 2020 14:21
Graph Attention Networks
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn


class GAT(nn.Module):

    def __init__(self, n_input, n_inner, n_layers, alpha=0.1, dropout=0.5):
        super(GAT, self).__init__()

        self.n_input = n_input
        self.n_inner = n_inner
        self.n_layers = n_layers
        self.alpha = alpha
        self.dropout = dropout

        self.layers = nn.ModuleList([GATLayer(n_input, n_inner, alpha, dropout)
                                     for _ in range(n_layers)])

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += f"n_input={self.n_input}, "
        s += f"n_inner={self.n_inner}, "
        s += f"n_layers={self.n_layers}, "
        s += f"alpha={self.alpha}, "
        s += f"dropout={self.dropout}"
        s += ')'

        return s

    def forward(self, x, mask):
        # x: [batch_size, seq_len, n_input] node features
        # mask: BoolTensor broadcastable to [batch_size, seq_len, seq_len],
        #       with True marking a valid attention edge
        for layer in self.layers:
            x = layer(x, mask)

        return x
class GATLayer(nn.Module):

    def __init__(self, n_input, n_inner, alpha=0.1, dropout=0.5):
        super(GATLayer, self).__init__()

        self.n_input = n_input
        self.n_inner = n_inner
        self.alpha = alpha

        self.w = nn.Linear(n_input, n_inner, False)
        # attention vector for the node itself
        self.a1 = nn.Linear(n_inner, 1, False)
        # attention vector for its neighbours
        self.a2 = nn.Linear(n_inner, 1, False)
        self.leaky_relu = nn.LeakyReLU(self.alpha)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += f"n_input={self.n_input}, "
        s += f"n_inner={self.n_inner}, "
        s += f"alpha={self.alpha}, "
        s += f"dropout={self.dropout.p}"
        s += ')'

        return s

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.w.weight)
        nn.init.xavier_uniform_(self.a1.weight)
        nn.init.xavier_uniform_(self.a2.weight)

    def forward(self, x, mask):
        batch_size, seq_len, n_input = x.shape

        # project the input features
        # [batch_size, seq_len, n_inner]
        h = self.w(x)
        # pairwise scores e_ij = LeakyReLU(a1(h_i) + a2(h_j)),
        # i.e. the additive attention of GAT split into two linear terms
        # [batch_size, seq_len, seq_len]
        e = self.leaky_relu(self.a1(h) + self.a2(h).transpose(-1, -2))
        # mask out invalid positions before the softmax
        e = e.masked_fill(~mask, torch.finfo(e.dtype).min)
        # [batch_size, seq_len, seq_len]
        attn = self.dropout(e.softmax(-1))
        # aggregate neighbour features with the attention weights; this layer
        # attends over the raw input x rather than the projected h, so the
        # output keeps n_input dims and stacked layers stay shape-compatible
        # [batch_size, seq_len, n_input]
        h = torch.einsum('bij,bjh->bih', attn, x)

        return self.elu(h)
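
A minimal usage sketch (not part of the original gist), assuming node features of shape [batch_size, seq_len, n_input] and a boolean mask of shape [batch_size, seq_len, seq_len] with True marking allowed edges; all names and sizes below are illustrative.

import torch

batch_size, seq_len, n_input, n_inner = 2, 5, 16, 32
gat = GAT(n_input, n_inner, n_layers=2)

x = torch.randn(batch_size, seq_len, n_input)
# fully connected graphs here; in practice this would be the adjacency mask
mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)

out = gat(x, mask)
print(out.shape)  # torch.Size([2, 5, 16])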