@BarclayII
Last active January 5, 2026 15:46
PinSage example implementation
import pandas as pd
import dgl
import os
import torch


class MovieLens(object):
    def __init__(self, directory):
        '''
        directory: path to movielens directory which should have the three
        files:
            users.dat
            movies.dat
            ratings.dat
        '''
        self.directory = directory

        users = []
        movies = []
        ratings = []

        # read users
        with open(os.path.join(directory, 'users.dat')) as f:
            for l in f:
                id_, gender, age, occupation, zip_ = l.strip().split('::')
                users.append({
                    'id': int(id_),
                    'gender': gender,
                    'age': age,
                    'occupation': occupation,
                    'zip': zip_,
                })
        self.users = pd.DataFrame(users).set_index('id')

        # read movies
        with open(os.path.join(directory, 'movies.dat'), encoding='latin1') as f:
            for l in f:
                id_, title, genres = l.strip().split('::')
                genres_set = set(genres.split('|'))
                data = {'id': int(id_), 'title': title}
                for g in genres_set:
                    data[g] = True
                movies.append(data)
        self.movies = pd.DataFrame(movies).set_index('id')

        # read ratings
        with open(os.path.join(directory, 'ratings.dat')) as f:
            for l in f:
                user_id, movie_id, rating, timestamp = [int(_) for _ in l.strip().split('::')]
                ratings.append({
                    'user_id': user_id,
                    'movie_id': movie_id,
                    'rating': rating,
                    'timestamp': timestamp,
                })
        self.ratings = pd.DataFrame(ratings)

        # randomly generate a training-validation-test split on the ratings table
        test_set = self.ratings.sample(frac=0.05, random_state=1).index
        valid_set = self.ratings.sample(frac=0.05, random_state=2).index
        valid_set = valid_set.difference(test_set)
        self.ratings['valid'] = self.ratings.index.isin(valid_set)
        self.ratings['test'] = self.ratings.index.isin(test_set)

    def todglgraph(self):
        '''
        returns g, user_ids, movie_ids:
            The DGL graph itself.  Each edge has a binary feature "valid" and
            a binary feature "test" indicating validation/test examples.
            The list of user IDs (node i corresponds to user user_ids[i])
            The list of movie IDs (node i + len(user_ids) corresponds to
            movie movie_ids[i])
        '''
        user_ids = list(self.users.index)
        movie_ids = list(self.movies.index)
        user_ids_invmap = {id_: i for i, id_ in enumerate(user_ids)}
        movie_ids_invmap = {id_: i for i, id_ in enumerate(movie_ids)}

        # bipartite graph layout: user nodes come first, movie nodes are
        # offset by the number of users
        g = dgl.DGLGraph()
        g.add_nodes(len(user_ids) + len(movie_ids))
        rating_user_vertices = [user_ids_invmap[id_] for id_ in self.ratings['user_id'].values]
        rating_movie_vertices = [movie_ids_invmap[id_] + len(user_ids)
                                 for id_ in self.ratings['movie_id'].values]
        valid_tensor = torch.from_numpy(self.ratings['valid'].values.astype('uint8'))
        test_tensor = torch.from_numpy(self.ratings['test'].values.astype('uint8'))

        # add edges in both directions so random walks can alternate between
        # users and movies
        g.add_edges(rating_user_vertices,
                    rating_movie_vertices,
                    data={'valid': valid_tensor, 'test': test_tensor})
        g.add_edges(rating_movie_vertices,
                    rating_user_vertices,
                    data={'valid': valid_tensor, 'test': test_tensor})
        return g, user_ids, movie_ids
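
# A minimal usage sketch: assumes the MovieLens-1M files sit under './ml-1m'
# (a placeholder path; point it at wherever users.dat/movies.dat/ratings.dat
# actually live).
if __name__ == '__main__':
    ml = MovieLens('./ml-1m')
    g, user_ids, movie_ids = ml.todglgraph()
    print('%d nodes, %d edges' % (g.number_of_nodes(), g.number_of_edges()))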
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl


def get_embeddings(h, nodeset):
    return h[nodeset]


def put_embeddings(h, nodeset, new_embeddings):
    n_nodes = nodeset.shape[0]
    n_features = h.shape[1]
    return h.scatter(0, nodeset[:, None].expand(n_nodes, n_features), new_embeddings)
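
# Semantics sketch (assuming h is a plain dense tensor): scatter along dim 0
# with a (n_nodes, n_features) index writes row i of new_embeddings into row
# nodeset[i] of a *copy* of h, so the update is out-of-place and gradients
# can flow into both h and new_embeddings.  For example:
#     h = torch.zeros(4, 2)
#     h2 = put_embeddings(h, torch.tensor([1, 3]), torch.ones(2, 2))
#     # h2 rows 1 and 3 are ones; h itself is unchanged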
def random_walk_sampler(G, nodeset, n_traces, n_hops):
    '''
    G: DGLGraph
    nodeset: 1D CPU Tensor of node IDs
    n_traces: int
    n_hops: int
    return: 3D CPU Tensor of node IDs (n_nodes, n_traces, n_hops + 1)
    '''
    n_nodes = nodeset.shape[0]
    traces = torch.zeros(n_nodes, n_traces, n_hops + 1, dtype=torch.int64)
    for i in range(n_nodes):
        for j in range(n_traces):
            cur = nodeset[i]
            for k in range(n_hops + 1):
                traces[i, j, k] = cur
                # uniformly pick an out-neighbor as the next step of the walk
                neighbors = G.successors(cur)
                assert neighbors.shape[0] > 0
                cur = neighbors[torch.randint(len(neighbors), ())]
    return traces
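
# A possible vectorized alternative, assuming a newer DGL (0.5+) that ships
# dgl.sampling.random_walk; the triple Python loop above is correct but slow
# on large node sets.  Exact signature details may differ across DGL versions:
#     traces, _ = dgl.sampling.random_walk(
#         G, nodeset.repeat_interleave(n_traces), length=n_hops)
#     traces = traces.view(nodeset.shape[0], n_traces, n_hops + 1)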
def random_walk_distribution(G, nodeset, n_traces, n_hops):
    n_nodes = nodeset.shape[0]
    n_available_nodes = G.number_of_nodes()
    traces = random_walk_sampler(G, nodeset, n_traces, n_hops)
    visited_nodes = traces[:, :, 1:].reshape(n_nodes, -1)  # (n_nodes, n_visited_other_nodes)
    # count the visits to each node, then normalize the counts into a
    # probability distribution per starting node
    visited_counts = (
        torch.zeros(n_nodes, n_available_nodes)
        .scatter_add_(1, visited_nodes, torch.ones_like(visited_nodes, dtype=torch.float32)))
    visited_prob = visited_counts / visited_counts.sum(1, keepdim=True)
    return visited_prob


def random_walk_distribution_topt(G, nodeset, n_traces, n_hops, top_T):
    '''
    returns the top T important neighbors of each node in nodeset, as well as
    the weights of the neighbors.
    '''
    visited_prob = random_walk_distribution(G, nodeset, n_traces, n_hops)
    # top-T entries along the node dimension: returns (weights, node IDs)
    return visited_prob.topk(top_T, 1)
def random_walk_nodeflow(G, nodeset, n_layers, n_traces, n_hops, top_T):
    '''
    returns a list of triplets (
        "active" node IDs whose embeddings are computed at the i-th layer (num_nodes,)
        weight of each neighboring node of each "active" node on the i-th layer (num_nodes, top_T)
        neighboring node IDs for each "active" node on the i-th layer (num_nodes, top_T)
    )
    '''
    nodeflow = []
    cur_nodeset = nodeset
    for i in reversed(range(n_layers)):
        nb_weights, nb_nodes = random_walk_distribution_topt(G, cur_nodeset, n_traces, n_hops, top_T)
        nodeflow.insert(0, (cur_nodeset, nb_weights, nb_nodes))
        # the active node set grows as we move toward the input layer: it must
        # also contain every sampled neighbor of the current node set
        cur_nodeset = torch.cat([nb_nodes.view(-1), cur_nodeset]).unique()
    return nodeflow
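
# Shape sketch: with n_layers=2 and top_T=3, nodeflow looks like
#     [(innermost nodeset (N2,), weights (N2, 3), neighbors (N2, 3)),
#      (minibatch nodeset (N1,), weights (N1, 3), neighbors (N1, 3))]
# where N2 >= N1, since the innermost set contains the minibatch nodes plus
# every neighbor sampled for them; layer 0 is computed first so that layer 1
# can consume its outputs.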
class PinSageConv(nn.Module):
    def __init__(self, in_features, out_features, hidden_features):
        super(PinSageConv, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features

        self.Q = nn.Linear(in_features, hidden_features)
        self.W = nn.Linear(in_features + hidden_features, out_features)

    def forward(self, h, nodeset, nb_nodes, nb_weights):
        '''
        h: node embeddings (num_total_nodes, in_features), or a container
           of the node embeddings (for distributed computing)
        nodeset: node IDs in this minibatch (num_nodes,)
        nb_nodes: neighbor node IDs of each node in nodeset (num_nodes, num_neighbors)
        nb_weights: weight of each neighbor node (num_nodes, num_neighbors)
        return: new node embeddings (num_nodes, out_features)
        '''
        n_nodes, T = nb_nodes.shape

        h_nodeset = get_embeddings(h, nodeset)  # (n_nodes, in_features)
        h_neighbors = get_embeddings(h, nb_nodes.view(-1)).view(n_nodes, T, self.in_features)

        # transform neighbor embeddings, then take their importance-weighted average
        h_neighbors = F.relu(self.Q(h_neighbors))
        h_agg = (nb_weights[:, :, None] * h_neighbors).sum(1) / nb_weights.sum(1, keepdim=True)

        # concatenate self and aggregated neighborhood, project, and L2-normalize
        h_concat = torch.cat([h_nodeset, h_agg], 1)
        h_new = F.relu(self.W(h_concat))
        h_new = h_new / h_new.norm(dim=1, keepdim=True)
        return h_new
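
# Shape-check sketch: with in_features = hidden_features = out_features = 64,
# T = 3 neighbors and a minibatch of 5 nodes, forward maps
#     h (N_total, 64), nodeset (5,), nb_nodes (5, 3), nb_weights (5, 3)
# to L2-normalized embeddings of shape (5, 64).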
class PinSage(nn.Module):
    '''
    Performs a multi-layer PinSage convolution.
    G: DGLGraph
    feature_sizes: the dimensionality of input/hidden/output features
    T: number of neighbors we pick for each node
    n_traces: number of random walk traces to generate during sampling
    n_hops: number of hops of each random walk trace during sampling
    '''
    def __init__(self, G, feature_sizes, T, n_traces, n_hops):
        super(PinSage, self).__init__()

        self.G = G
        self.T = T
        self.n_traces = n_traces
        self.n_hops = n_hops

        self.in_features = feature_sizes[0]
        self.out_features = feature_sizes[-1]
        self.n_layers = len(feature_sizes) - 1

        self.convs = nn.ModuleList()
        for i in range(self.n_layers):
            self.convs.append(PinSageConv(
                feature_sizes[i], feature_sizes[i+1], feature_sizes[i+1]))

    def forward(self, h, nodeset):
        '''
        Given a complete embedding matrix h and a list of node IDs, return
        the output embeddings of these node IDs.
        h: node embeddings (num_total_nodes, in_features), or a container
           of the node embeddings (for distributed computing)
        nodeset: node IDs in this minibatch (num_nodes,)
        return: new node embeddings (num_nodes, out_features)
        '''
        nodeflow = random_walk_nodeflow(
            self.G, nodeset, self.n_layers, self.n_traces, self.n_hops, self.T)

        # process layers from the innermost (largest) node set outward; note
        # that put_embeddings writes back into h, so every feature size must
        # match h's width.  After the loop, nodeset holds the original
        # minibatch nodes again.
        for i, (nodeset, nb_weights, nb_nodes) in enumerate(nodeflow):
            new_embeddings = self.convs[i](h, nodeset, nb_nodes, nb_weights)
            h = put_embeddings(h, nodeset, new_embeddings)
        return get_embeddings(h, nodeset)
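
# A minimal end-to-end training sketch, assuming the MovieLens class above is
# in scope and the MovieLens-1M files live under './ml-1m' (a placeholder
# path).  The max-margin ranking loss follows the spirit of the PinSage
# paper; learnable free input embeddings and uniform negative sampling over
# movies are simplifications, not part of the original gist.
def train_sketch(directory='./ml-1m', n_iters=100, batch_size=256, margin=1.0):
    ml = MovieLens(directory)
    g, user_ids, movie_ids = ml.todglgraph()
    n_users, n_movies = len(user_ids), len(movie_ids)
    user_invmap = {id_: i for i, id_ in enumerate(user_ids)}
    movie_invmap = {id_: i for i, id_ in enumerate(movie_ids)}

    emb = nn.Parameter(torch.randn(g.number_of_nodes(), 64))
    model = PinSage(g, [64, 64, 64], T=3, n_traces=10, n_hops=2)
    opt = torch.optim.Adam(list(model.parameters()) + [emb], lr=1e-3)

    train = ml.ratings[~ml.ratings['valid'] & ~ml.ratings['test']]
    for it in range(n_iters):
        batch = train.sample(batch_size)
        # map raw IDs to graph node IDs (movies are offset by n_users)
        u = torch.LongTensor([user_invmap[i] for i in batch['user_id']])
        pos = torch.LongTensor([movie_invmap[i] + n_users for i in batch['movie_id']])
        neg = torch.randint(n_users, n_users + n_movies, (batch_size,))

        z = model(emb, torch.cat([u, pos, neg]))
        z_u, z_pos, z_neg = z[:batch_size], z[batch_size:2 * batch_size], z[2 * batch_size:]
        # max-margin ranking loss: push positive pairs above negative pairs
        loss = F.relu((z_u * z_neg).sum(1) - (z_u * z_pos).sum(1) + margin).mean()

        opt.zero_grad()
        loss.backward()
        opt.step()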
@LogicMatthew commented Jan 5, 2026
I think the reversing in random_walk_nodeflow is unnecessary and a bit misleading at first glance. The loop index effectively acts as a plain counter here, so this would be clearer:

for _ in range(n_layers):

Second thing:

nb_weights, nb_nodes = random_walk_distribution_topt(G, nodeset, n_traces, n_hops, top_T)

Here nodeset should be cur_nodeset, because the set of active nodes grows at each layer and the sampling should follow the updated node set.
