Last active
January 17, 2019 20:26
-
-
Save felixgwu/14636d7afce34f0bec78e96839b60889 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from lanczos_utils import check_dist | |
# Single-precision machine epsilon, stored as a builtin Python float.
EPS = np.finfo(np.float32).eps.item()
# Public API exported by ``from <module> import *``.
__all__ = ["LanczosNet"]
class LanczosNet(nn.Module):
    """Lanczos Network (LanczosNet) producing per-node class probabilities.

    Each layer concatenates three groups of propagated node features:

    * short-range diffusion ``L^k X`` for every ``k`` in ``short_diffusion_dist``
      (computed by repeated multiplication with ``L``);
    * long-range diffusion ``V diag(f(D^k)) V^T X`` for every ``k`` in
      ``long_diffusion_dist``, built from Ritz values ``D`` and Ritz vectors
      ``V``. ``f`` is a per-layer MLP when ``spectral_filter_kind == 'MLP'``
      and the identity otherwise;
    * one extra single-step term ``L X``.

    The concatenation goes through a per-layer ``nn.Linear``, then ReLU and
    dropout. The last layer's output is finally passed through a softmax.

    Args:
        nfeat: input feature dimension.
        nhid: hidden feature dimension.
        nclass: output dimension (number of classes).
        dropout: dropout probability applied after every layer.
        num_layer: number of propagation layers.
        num_eig_vec: number of Ritz pairs; stored but not read by this class.
        spectral_filter_kind: ``'MLP'`` to learn the spectral filter; any other
            value uses the raw powers of the Ritz values.
        short_diffusion_dist: diffusion steps of the short-range terms;
            defaults to ``[1, 2, 5, 7]``.
        long_diffusion_dist: powers of the long-range spectral terms;
            defaults to ``[10, 20, 30]``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, num_layer=2, num_eig_vec=20,
                 spectral_filter_kind='MLP', short_diffusion_dist=None,
                 long_diffusion_dist=None):
        super().__init__()
        # Build fresh lists per call: a list literal used as a default value would
        # be one object shared by every instance created without the argument.
        if short_diffusion_dist is None:
            short_diffusion_dist = [1, 2, 5, 7]
        if long_diffusion_dist is None:
            long_diffusion_dist = [10, 20, 30]

        self.input_dim = nfeat
        self.hidden_dim = nhid
        self.output_dim = nclass
        self.num_layer = num_layer
        self.dropout = dropout
        self.short_diffusion_dist = check_dist(short_diffusion_dist)
        self.long_diffusion_dist = check_dist(long_diffusion_dist)
        self.max_short_diffusion_dist = (
            max(self.short_diffusion_dist) if self.short_diffusion_dist else None)
        self.max_long_diffusion_dist = (
            max(self.long_diffusion_dist) if self.long_diffusion_dist else None)
        self.num_scale_short = len(self.short_diffusion_dist)
        self.num_scale_long = len(self.long_diffusion_dist)
        self.num_eig_vec = num_eig_vec
        self.spectral_filter_kind = spectral_filter_kind

        # Per-layer feature sizes: input -> hidden (num_layer - 1 times) -> output.
        dim_list = [self.input_dim] + [self.hidden_dim] * (num_layer - 1) + [self.output_dim]
        # Each layer consumes one message per short scale, one per long scale,
        # plus the extra single-step term ``L X`` (hence the ``+ 1``).
        num_msg = self.num_scale_short + self.num_scale_long + 1
        self.filter = nn.ModuleList([
            nn.Linear(dim_list[tt] * num_msg, dim_list[tt + 1])
            for tt in range(self.num_layer)
        ])

        # Learnable spectral filters: one small MLP per layer, applied to the
        # vector of Ritz-value powers of each eigen-direction.
        if self.spectral_filter_kind == 'MLP' and self.num_scale_long > 0:
            self.spectral_filter = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(self.num_scale_long, 128),
                    nn.ReLU(),
                    nn.Linear(128, self.num_scale_long),
                ) for _ in range(self.num_layer)
            ])

        self._init_param()

    def _init_param(self):
        """Xavier-uniform init for every ``nn.Linear`` weight; zero the biases."""
        for ff in self.filter:
            if isinstance(ff, nn.Linear):
                nn.init.xavier_uniform_(ff.weight.data)
                if ff.bias is not None:
                    ff.bias.data.zero_()

        if self.spectral_filter_kind == 'MLP' and self.num_scale_long > 0:
            for mlp in self.spectral_filter:
                for f in mlp:
                    if isinstance(f, nn.Linear):
                        nn.init.xavier_uniform_(f.weight.data)
                        if f.bias is not None:
                            f.bias.data.zero_()

    def _get_spectral_filters(self, T_list, Q, layer_idx):
        """Construct the long-range spectral filters from Lanczos outputs.

        Args:
            T_list: ``num_scale_long`` tensors, each of shape B X K
                (powers of the Ritz values).
            Q: Ritz vectors, shape B X N X K.
            layer_idx: index of the layer whose spectral MLP is used.

        Returns:
            Tensor of shape B X N X N X num_scale_long. Slice ``[..., ii]`` is
            ``Q diag(d_ii) Q^T``. ``d_ii`` is ``T_list[ii]``, or the MLP's
            ``ii``-th output when ``spectral_filter_kind == 'MLP'``.
        """
        batch_size, _, num_eig = Q.shape
        if self.spectral_filter_kind == 'MLP':
            DD = torch.stack(T_list, dim=2).view(batch_size * num_eig, -1)  # shape BK X S
            DD = self.spectral_filter[layer_idx](DD).view(
                batch_size, num_eig, -1)  # shape B X K X S
            diags = [DD[:, :, ii] for ii in range(self.num_scale_long)]
        else:
            diags = [T_list[ii] for ii in range(self.num_scale_long)]

        Q_t = Q.transpose(1, 2)
        # Broadcasting a B X 1 X K tensor against Q (B X N X K) scales the columns
        # of Q, i.e. forms Q diag(d) without materializing a repeated copy.
        L = [(Q * dd.unsqueeze(1)).bmm(Q_t) for dd in diags]
        return torch.stack(L, dim=3)

    def forward(self, node_feat, L, D, V):
        """Run all LanczosNet layers and return per-node class probabilities.

        Shape parameters:
            batch size = B
            input feature dim = D_in
            max number of nodes within one mini batch = N
            number of approximated eigenvalues, i.e., Ritz values = K

        If ``node_feat`` is 2-D, all inputs are treated as a single unbatched
        graph: a batch axis is added and removed again from the result.

        Args:
            node_feat: float tensor, shape B X N X D_in (or N X D_in)
            L: float tensor, shape B X N X N (or N X N); one diffusion step is
                ``bmm(L, X)``
            D: float tensor, Ritz values, shape B X K (or K)
            V: float tensor, Ritz vectors, shape B X N X K (or N X K)

        Returns:
            Softmax over the last axis of the final layer output, shape
            B X N X nclass (N X nclass for unbatched input).
        """
        # Always defined, so batched (3-D) inputs no longer hit an unbound name.
        squeeze = node_feat.dim() == 2
        if squeeze:
            node_feat = node_feat.unsqueeze(0)
            L = L.unsqueeze(0)
            D = D.unsqueeze(0)
            V = V.unsqueeze(0)

        batch_size = node_feat.shape[0]
        num_node = node_feat.shape[1]
        # One power of the Ritz values per long diffusion distance, each B X K.
        D_pow_list = [torch.pow(D, ii) for ii in self.long_diffusion_dist]

        ###########################################################################
        # Graph Convolution
        ###########################################################################
        state = node_feat
        for tt in range(self.num_layer):
            msg = []

            # short diffusion: L^k X for each requested k, reusing the running
            # product. Step 0 is included so a distance of 0 contributes X itself
            # instead of being silently skipped.
            if self.num_scale_short > 0:
                tmp_state = state
                for ii in range(self.max_short_diffusion_dist + 1):
                    if ii > 0:
                        tmp_state = torch.bmm(L, tmp_state)
                    if ii in self.short_diffusion_dist:
                        msg.append(tmp_state)

            # long diffusion
            if self.num_scale_long > 0:
                Lf = self._get_spectral_filters(D_pow_list, V, tt)
                msg.extend(torch.bmm(Lf[:, :, :, ii], state)  # shape: B X N X D
                           for ii in range(self.num_scale_long))

            msg.append(torch.bmm(L, state))  # shape: B X N X D

            msg = torch.cat(msg, dim=2).view(batch_size * num_node, -1)
            # NOTE(review): ReLU and dropout are also applied to the last layer's
            # output before the softmax; kept as-is, confirm this is intended.
            state = F.relu(self.filter[tt](msg)).view(batch_size, num_node, -1)
            state = F.dropout(state, self.dropout, training=self.training)

        if squeeze:
            state = state.squeeze(0)
        return F.softmax(state, dim=-1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment