Skip to content

Instantly share code, notes, and snippets.

@eileen-code4fun
Created January 7, 2022 20:04
Show Gist options
  • Save eileen-code4fun/c58d8b5541047b796f2c41939b610639 to your computer and use it in GitHub Desktop.
Save eileen-code4fun/c58d8b5541047b796f2c41939b610639 to your computer and use it in GitHub Desktop.
GIN Definition
import tensorflow as tf
from dgl.nn.tensorflow import GINConv
class GIN(tf.keras.Model):
    """Two-layer Graph Isomorphism Network built from DGL ``GINConv`` layers.

    Each ``GINConv`` layer uses sum aggregation with a learnable epsilon and
    applies a small two-layer MLP (see :meth:`mlp`) to the aggregated
    features.
    """

    @staticmethod
    def mlp(feat_dim, hidden_dim, out_dim):
        """Build the MLP used as a ``GINConv`` ``apply_func``.

        The network is ``Dense(hidden_dim, relu) -> Dropout(0.5) ->
        Dense(out_dim)``; the last layer has no activation.

        Args:
            feat_dim: Size of the input feature vector.
            hidden_dim: Number of units in the hidden dense layer.
            out_dim: Number of units in the output dense layer.

        Returns:
            A ``tf.keras.models.Sequential`` model.
        """
        # Declared as a staticmethod: it takes no ``self``, and a plain
        # function defined in the class body cannot be reached from another
        # method by its bare name (class scope is not an enclosing scope).
        m = tf.keras.models.Sequential()
        m.add(tf.keras.layers.Input(shape=(feat_dim,)))
        m.add(tf.keras.layers.Dense(hidden_dim, activation='relu'))
        m.add(tf.keras.layers.Dropout(0.5))
        m.add(tf.keras.layers.Dense(out_dim))
        return m

    def __init__(self, feat_dim, hidden_dim, class_num):
        """Create the two ``GINConv`` layers.

        Args:
            feat_dim: Size of each input feature vector.
            hidden_dim: Width of the hidden representation; it is both the
                output size of the first layer and the input size of the
                second.
            class_num: Output size of the second layer.
        """
        super().__init__()
        # Call through ``self`` so the helper is actually resolved; the
        # original bare ``mlp(...)`` raised NameError at construction time.
        self.h1 = GINConv(
            apply_func=self.mlp(feat_dim, hidden_dim, hidden_dim),
            aggregator_type='sum',
            learn_eps=True,
        )
        self.h2 = GINConv(
            apply_func=self.mlp(hidden_dim, hidden_dim, class_num),
            aggregator_type='sum',
            learn_eps=True,
        )

    def call(self, g, features):
        """Run both ``GINConv`` layers in sequence on graph ``g``.

        Args:
            g: Graph passed unchanged to both ``GINConv`` layers
                (presumably a ``DGLGraph`` -- confirm against the caller).
            features: Input features fed to the first layer (assumed to
                have last dimension ``feat_dim`` -- TODO confirm).

        Returns:
            The output of the second ``GINConv`` layer. Its MLP ends in a
            ``Dense(class_num)`` with no activation, so these are raw,
            un-normalized scores.
        """
        h = features
        h = self.h1(g, h)
        h = self.h2(g, h)
        return h
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment