Skip to content

Instantly share code, notes, and snippets.

@maxidl
Created April 22, 2019 20:44
Show Gist options
  • Save maxidl/da5c1a9bc0f3c41310514b478cc7a183 to your computer and use it in GitHub Desktop.
Minimalist deepwalk implementation
import numpy as np
import pandas as pd
from scipy import sparse
from pathlib import Path
from joblib import Parallel, delayed
from itertools import chain
from random import shuffle
from tqdm import tqdm
from gensim.models import Word2Vec
import logging
def sample_random_walks(adj_matrix, walks_per_node, walk_length, workers):
    """Sample truncated uniform random walks from every node seen in the graph.

    Parameters
    ----------
    adj_matrix : scipy.sparse.csr_matrix
        Adjacency matrix; row i lists the neighbors of node i.
    walks_per_node : int
        Number of walks started from each node.
    walk_length : int
        Maximum number of steps per walk, so a walk holds up to
        walk_length + 1 nodes; it ends early at a node with no neighbors.
    workers : int
        Number of joblib workers used to parallelize over start nodes.

    Returns
    -------
    list[list[str]]
        One list per walk, node ids rendered as strings (Word2Vec input).
    """
    def _single_walk(start_node, steps):
        # Step uniformly at random among the current node's neighbors.
        walk = [start_node]
        for _ in range(steps):
            candidates = adj_matrix.getrow(walk[-1]).indices
            if candidates.size == 0:
                break  # dead end: no outgoing edges
            walk.append(np.random.choice(candidates))
        return [str(node) for node in walk]

    def _walks_from(start_node):
        return [_single_walk(start_node, walk_length) for _ in range(walks_per_node)]

    # NOTE(review): set(adj_matrix.indices) covers nodes that occur as a
    # neighbor of something; with the mirrored (undirected) edges built by
    # the caller this is every non-isolated node.
    start_nodes = set(adj_matrix.indices)
    per_node = Parallel(n_jobs=workers)(
        delayed(_walks_from)(node) for node in tqdm(start_nodes))
    return list(chain.from_iterable(per_node))
logging.basicConfig(format="%(levelname)s - %(asctime)s: %(message)s", datefmt='%H:%M:%S', level=logging.INFO)
cwd = Path('.')

logging.info('reading edgelist ...')
# Edge list file: two whitespace-separated integer node-id columns, no header.
edgelist = pd.read_csv(cwd / 'wikipedia' / 'anchor_graph' / 'wiki_entities_anchor_edges.tsv', sep=' ', header=None)
logging.info(f'graph containing {edgelist.shape[0]} edges')

src_nodes, dest_nodes = edgelist.iloc[:, 0], edgelist.iloc[:, 1]
weights = np.ones(edgelist.shape[0], dtype=np.int64)
# Mirror every edge in both directions so the graph is undirected.
src_nodes, dest_nodes, weights = (np.concatenate((src_nodes, dest_nodes)),
                                  np.concatenate((dest_nodes, src_nodes)),
                                  np.concatenate((weights, weights)))
# BUG FIX: the adjacency matrix must be (num_nodes x num_nodes), not
# (num_edges x num_edges). Size it by the largest node id so every node gets
# a row/column; duplicate (src, dest) entries are summed by coo -> csr.
num_nodes = int(max(src_nodes.max(), dest_nodes.max())) + 1
adj_matrix = sparse.coo_matrix((weights, (src_nodes, dest_nodes)),
                               shape=(num_nodes, num_nodes)).tocsr()
# ---- hyper-parameters (standard DeepWalk settings) ----
embed_dim = 128       # dimensionality of the node embeddings
window_size = 10      # skip-gram context window over each walk
neg_samples = 5       # negative samples per positive pair
epochs = 5            # passes over the walk corpus
walk_length = 40      # maximum steps per walk
walks_per_node = 80   # walks started from each node
workers = 20          # parallel workers for sampling and training

logging.info(f'sampling random walks using {workers} threads')
walks = sample_random_walks(adj_matrix, walks_per_node, walk_length, workers)
shuffle(walks)  # decorrelate consecutive sentences before training

# Train skip-gram (sg=1) with negative sampling (hs=0) on the walk corpus.
# NOTE(review): `size=` / `iter=` are gensim<4 keyword names; gensim>=4
# renamed them to vector_size= / epochs= — confirm the installed version.
word2vec = Word2Vec(sentences=walks, min_count=0, size=embed_dim,
                    window=window_size, sg=1, hs=0, negative=neg_samples,
                    workers=workers, sorted_vocab=0, iter=epochs)

# Persist the embeddings in both binary and text word2vec formats.
word2vec.wv.save_word2vec_format(cwd / 'entity_graph_embeddings.bin', binary=True)
word2vec.wv.save_word2vec_format(cwd / 'entity_graph_embeddings.txt')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment