Created December 3, 2020 09:03
Efficient Sparse Transformer implementation with DGL's builtin operators
import dgl
import dgl.ops as ops
import numpy as np
import torch as th
import torch.nn as nn
class FFN(nn.Module):
    """Position-wise feed-forward network, applied to each node independently."""
    def __init__(self, d_feat, d_ffn, dropout=0.1):
        super().__init__()
        self.linear_0 = nn.Linear(d_feat, d_ffn)
        self.linear_1 = nn.Linear(d_ffn, d_feat)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear_1(self.dropout(th.relu(self.linear_0(x))))
class TransLayer(nn.Module):
    def __init__(self,
                 d_feat=512,
                 h=8,
                 d_qkv=64,
                 d_ffn=2048,
                 dropout_h=0.1,
                 dropout_a=0.1):
        """
        Parameters
        ----------
        d_feat : int
            The feature dimension, defaults to 512.
        h : int
            The number of heads, defaults to 8.
        d_qkv : int
            The query/key/value dimension per head, defaults to 64.
        d_ffn : int
            The inner (hidden) dimension of the FFN layer, defaults to 2048.
        dropout_h : float
            The dropout rate on features, defaults to 0.1.
        dropout_a : float
            The dropout rate on attention scores, defaults to 0.1.
        """
        super().__init__()
        self.d_feat = d_feat
        self.h = h
        self.d_qkv = d_qkv
        self.d_ffn = d_ffn
        self.proj_q = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_k = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_v = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_o = nn.Linear(h * d_qkv, d_feat, bias=False)
        self.norm_in = nn.LayerNorm(self.d_feat)
        self.norm_inter = nn.LayerNorm(self.d_feat)
        self.drop_h = nn.Dropout(dropout_h)
        self.drop_att = nn.Dropout(dropout_a)
        self.ffn = FFN(d_feat, d_ffn)
    def forward(self, g, h):
        """Forward function of the Transformer layer on a DGLGraph.

        Parameters
        ----------
        g : DGLGraph
            The graph defining the (sparse) attention pattern: an edge (u, v)
            means node v attends to node u.
        h : Tensor
            Input node features of shape (N, d_feat).

        Returns
        -------
        Tensor
            Output node features of shape (N, d_feat).
        """
        q = self.proj_q(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        k = self.proj_k(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        v = self.proj_v(h).view(-1, self.h, self.d_qkv)  # (N, h, d_qkv)
        e = ops.u_dot_v(g, k, q)  # (E, h, 1): per-edge dot product of key (source) and query (destination)
        # Notice: we can also implement relative positional encoding efficiently with
        # operators such as u_dot_e, e_dot_v; see the sketch after this class.
        a = self.drop_att(ops.edge_softmax(g, e / np.sqrt(self.d_qkv)))  # (E, h, 1): attention score
        wv = ops.u_mul_e_sum(g, v, a)  # (N, h, d_qkv): weighted sum of values by attention scores
        o = self.drop_h(self.proj_o(wv.view(-1, self.h * self.d_qkv)))  # (N, d_feat): output projection
        h = self.norm_in(h + o)
        h = self.norm_inter(h + self.ffn(h))
        return h
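
# ---------------------------------------------------------------------------
# Sketch (not part of the original gist): relative positional encoding with
# edge features, as hinted at by the u_dot_e / e_dot_v note above. Following
# the additive formulation score(u, v) = q_v . k_u + q_v . r_{uv}, the extra
# term is a dot product between an edge feature r_{uv} and the destination
# node's query, which maps directly onto ops.e_dot_v. The class and names
# below (RelPosAttention, rel_pos, num_rel_pos) are illustrative assumptions,
# and only the attention step is shown (no dropout, residual, or LayerNorm).
class RelPosAttention(nn.Module):
    def __init__(self, d_feat=512, h=8, d_qkv=64, num_rel_pos=32):
        super().__init__()
        self.h, self.d_qkv = h, d_qkv
        self.proj_q = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_k = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_v = nn.Linear(d_feat, h * d_qkv, bias=False)
        self.proj_o = nn.Linear(h * d_qkv, d_feat, bias=False)
        # One learnable per-head key vector for each bucketized relative distance.
        self.rel_emb = nn.Embedding(num_rel_pos, h * d_qkv)

    def forward(self, g, h, rel_pos):
        # rel_pos : LongTensor of shape (E,), a bucketized relative distance per edge.
        q = self.proj_q(h).view(-1, self.h, self.d_qkv)          # (N, h, d_qkv)
        k = self.proj_k(h).view(-1, self.h, self.d_qkv)          # (N, h, d_qkv)
        v = self.proj_v(h).view(-1, self.h, self.d_qkv)          # (N, h, d_qkv)
        r = self.rel_emb(rel_pos).view(-1, self.h, self.d_qkv)   # (E, h, d_qkv)
        # Content term plus relative-position term, both computed only on existing edges.
        e = ops.u_dot_v(g, k, q) + ops.e_dot_v(g, r, q)          # (E, h, 1)
        a = ops.edge_softmax(g, e / np.sqrt(self.d_qkv))         # (E, h, 1)
        wv = ops.u_mul_e_sum(g, v, a)                            # (N, h, d_qkv)
        return self.proj_o(wv.view(-1, self.h * self.d_qkv))     # (N, d_feat)
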
if __name__ == "__main__":
    layer = TransLayer()
    g = dgl.rand_graph(30, 100)  # a random graph with 30 nodes and 100 edges
    h = th.rand(30, 512)
    # g = g.to(0)
    # h = h.to(0)
    print(layer(g, h))
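
The graph g is what makes the attention sparse: scores are only computed on existing edges, so the cost scales with the number of edges rather than quadratically with the number of nodes. As a usage sketch (not part of the original gist), the snippet below builds a sliding-window attention pattern of half-width w over a length-n sequence and runs the layer on it; the helper name sliding_window_graph and the chosen sizes are illustrative.

def sliding_window_graph(n, w):
    # Each position i attends to positions j with |i - j| <= w (including itself).
    src, dst = [], []
    for i in range(n):
        for j in range(max(0, i - w), min(n, i + w + 1)):
            src.append(j)
            dst.append(i)
    return dgl.graph((th.tensor(src), th.tensor(dst)), num_nodes=n)

g = sliding_window_graph(1024, 4)   # at most 2*4 + 1 = 9 edges per node instead of 1024
h = th.rand(1024, 512)
layer = TransLayer()
print(layer(g, h).shape)            # torch.Size([1024, 512])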