Skip to content

Instantly share code, notes, and snippets.

@eileen-code4fun
Last active January 7, 2022 20:04
Show Gist options
  • Save eileen-code4fun/40d809cbf1c16b918c0f5c1c8675085a to your computer and use it in GitHub Desktop.
Save eileen-code4fun/40d809cbf1c16b918c0f5c1c8675085a to your computer and use it in GitHub Desktop.
GraphSAGE Definition
from dgl.nn.tensorflow import SAGEConv
class SAGE(tf.keras.Model):
    """Two-layer GraphSAGE model built from DGL's TensorFlow ``SAGEConv``.

    Layer 1 maps ``feat_dim``-wide node features to ``hidden_dim`` using the
    ``'pool'`` neighbour aggregator, 0.5 feature dropout and a ReLU
    activation. Layer 2 maps that hidden representation to ``class_num``
    outputs with the same aggregator and no activation argument, so ``call``
    returns unnormalised per-node scores (presumably logits consumed by a
    softmax/cross-entropy loss in the caller -- TODO confirm).

    Args:
        feat_dim: Width of the input node features.
        hidden_dim: Width of the hidden layer produced by the first layer.
        class_num: Number of output classes (width of the final layer).
    """

    def __init__(self, feat_dim, hidden_dim, class_num):
        super(SAGE, self).__init__()
        self.h1 = SAGEConv(
            in_feats=feat_dim,
            out_feats=hidden_dim,
            aggregator_type='pool',
            feat_drop=0.5,
            activation=tf.nn.relu,
        )
        # The second layer consumes h1's output, whose width is hidden_dim,
        # not feat_dim. DGL's 'pool' aggregator sizes its pooling projection
        # from in_feats, so passing feat_dim here mis-sizes that layer (and
        # breaks outright wherever weights are built eagerly from in_feats)
        # whenever feat_dim != hidden_dim.
        self.h2 = SAGEConv(
            in_feats=hidden_dim,
            out_feats=class_num,
            aggregator_type='pool',
        )

    def call(self, g, features):
        """Run both GraphSAGE layers over graph ``g``.

        Args:
            g: The graph to propagate over (presumably a ``dgl.DGLGraph``
                -- verify against caller).
            features: Input node features; assumed to be shaped
                ``(num_nodes, feat_dim)`` -- TODO confirm.

        Returns:
            The output of the second layer, ``class_num`` values per node.
        """
        h = self.h1(g, features)
        return self.h2(g, h)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment