# Minimalist DeepWalk implementation.
# Originally published as GitHub gist maxidl/da5c1a9bc0f3c41310514b478cc7a183
# (created April 22, 2019 20:44).
import inspect
import logging
from itertools import chain
from pathlib import Path
from random import shuffle

import numpy as np
import pandas as pd
from gensim.models import Word2Vec
from joblib import Parallel, delayed
from scipy import sparse
from tqdm import tqdm
def sample_random_walks(adj_matrix, walks_per_node, walk_length, workers):
    """Sample uniform (DeepWalk-style) random walks over a sparse graph.

    Walks start from every node that appears as a column index of the CSR
    adjacency matrix, i.e. every node with at least one incoming edge (for the
    symmetrised graph built in this script: every non-isolated node). Each
    step moves to a neighbour chosen uniformly at random from the current
    node's row; stored edge weights are ignored. A walk stops early when it
    reaches a node whose row is empty.

    Args:
        adj_matrix: scipy sparse adjacency matrix. CSR is used directly; any
            other sparse format is converted with ``tocsr()`` first.
        walks_per_node: number of walks started from each start node.
        walk_length: number of steps per walk, so a complete walk contains
            ``walk_length + 1`` nodes.
        workers: ``n_jobs`` passed to ``joblib.Parallel``.

    Returns:
        A flat list of walks; each walk is a list of node ids converted to
        ``str`` (the token format Word2Vec consumes). Walks of the same start
        node are adjacent in the list.
    """
    if adj_matrix.format != 'csr':
        adj_matrix = adj_matrix.tocsr()
    # Row i's neighbours live in indices[indptr[i]:indptr[i + 1]].
    indptr, indices = adj_matrix.indptr, adj_matrix.indices

    def random_walk(start_node, walk_length):
        path = [start_node]
        for _ in range(walk_length):
            curr_node = path[-1]
            # Slice the CSR arrays directly; far cheaper than building a
            # one-row matrix with getrow() on every step.
            neighbors = indices[indptr[curr_node]:indptr[curr_node + 1]]
            if len(neighbors) == 0:
                break  # dead end: no outgoing edges
            path.append(neighbors[np.random.randint(len(neighbors))])
        return list(map(str, path))

    def walks_from_node(node):
        return [random_walk(node, walk_length) for _ in range(walks_per_node)]

    # Sorted unique start nodes give a deterministic task order.
    start_nodes = np.unique(indices)
    walks = Parallel(n_jobs=workers)(
        delayed(walks_from_node)(node) for node in tqdm(start_nodes)
    )
    return list(chain.from_iterable(walks))
logging.basicConfig(format="%(levelname)s - %(asctime)s: %(message)s", datefmt='%H:%M:%S', level=logging.INFO)
cwd = Path('.')

logging.info('reading edgelist ...')
# NOTE(review): the file has a .tsv extension but is parsed with sep=' ';
# confirm the actual delimiter. Columns 0 and 1 are used as integer node ids.
edgelist = pd.read_csv(cwd / 'wikipedia' / 'anchor_graph' / 'wiki_entities_anchor_edges.tsv', sep=' ', header=None)
logging.info(f'graph containing {edgelist.shape[0]} edges')

src_nodes = edgelist.iloc[:, 0].to_numpy()
dest_nodes = edgelist.iloc[:, 1].to_numpy()
weights = np.ones(edgelist.shape[0], dtype=np.int64)

# Add every edge in both directions so the graph is treated as undirected.
src_nodes, dest_nodes = np.concatenate((src_nodes, dest_nodes)), np.concatenate((dest_nodes, src_nodes))
weights = np.concatenate((weights, weights))

# The matrix is indexed by node id, so its side length is (max node id + 1),
# not the number of edges: any id >= the edge count would make coo_matrix raise.
# After symmetrisation src_nodes contains every endpoint, so its max suffices.
num_nodes = int(src_nodes.max()) + 1
# tocsr() sums duplicate (src, dest) pairs into a single stored entry.
adj_matrix = sparse.coo_matrix((weights, (src_nodes, dest_nodes)), shape=(num_nodes, num_nodes)).tocsr()
# Hyperparameters.
# Skip-gram Word2Vec: embedding size, context window, negative samples, passes.
embed_dim, window_size, neg_samples, epochs = 128, 10, 5, 5
# Random-walk sampling: steps per walk and walks started from each node.
walk_length, walks_per_node = 40, 80
# Parallelism used both for walk sampling and for Word2Vec training.
workers = 20
# joblib's default backend runs jobs in worker processes, not threads.
logging.info(f'sampling random walks using {workers} workers')
walks = sample_random_walks(adj_matrix, walks_per_node, walk_length, workers)
# Walks come back grouped by start node; shuffle so training batches are mixed.
shuffle(walks)

# gensim 4 renamed Word2Vec's `size` -> `vector_size` and `iter` -> `epochs`.
# Pick whichever names the installed version accepts so both versions work.
w2v_params = inspect.signature(Word2Vec.__init__).parameters
dim_kwarg = 'vector_size' if 'vector_size' in w2v_params else 'size'
epochs_kwarg = 'epochs' if 'epochs' in w2v_params else 'iter'
# Skip-gram (sg=1) with negative sampling (hs=0); keep every token (min_count=0).
word2vec = Word2Vec(sentences=walks, min_count=0, window=window_size, sg=1, hs=0, negative=neg_samples,
                    workers=workers, sorted_vocab=0, **{dim_kwarg: embed_dim, epochs_kwarg: epochs})

# Paths are passed as str: older gensim/smart_open versions may not accept pathlib.Path.
word2vec.wv.save_word2vec_format(str(cwd / 'entity_graph_embeddings.bin'), binary=True)
word2vec.wv.save_word2vec_format(str(cwd / 'entity_graph_embeddings.txt'))
# End of gist (GitHub page footer: sign in to comment on this gist).