Created
January 9, 2018 16:05
-
-
Save mynameisfiber/960ccae07daa2d891df9f88bfd7e3fbe to your computer and use it in GitHub Desktop.
h5py file cache for word2vec model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import h5py | |
import os | |
import pickle | |
try: | |
import gensim | |
except ImportError: | |
gensim = None | |
class Word2VecLookup(object):
    """Disk-backed lookup of word2vec vectors stored in an HDF5 file.

    A database directory holds two files: ``db.h5py``, whose ``word2vec``
    dataset contains the embedding matrix, and ``lookup.pkl``, a pickled
    dict mapping each word to its row index in that matrix. Only the
    lookup dict is kept in memory; vectors are read from disk on access.
    """

    def __init__(self, dbpath):
        """Open an existing database directory.

        Args:
            dbpath: Directory previously populated by ``create_db``.

        Raises:
            TypeError: If ``db.h5py`` or ``lookup.pkl`` is missing from
                ``dbpath``.
        """
        self.h5file = os.path.join(dbpath, "db.h5py")
        self.lookupfile = os.path.join(dbpath, "lookup.pkl")
        if not (os.path.exists(self.h5file) and
                os.path.exists(self.lookupfile)):
            # TypeError is kept so existing callers that catch it still work;
            # the explanation now travels with the exception instead of
            # being printed to stdout.
            raise TypeError(
                "Word2VecLookup directory is malformed. Please recreate "
                "using Word2VecLookup.create_db"
            )
        # NOTE: unpickling can execute arbitrary code; only open database
        # directories that come from a trusted source.
        with open(self.lookupfile, 'rb') as fd:
            self.lookup = pickle.load(fd)

    def _lookup_indices(self, items):
        """Return the row index of every known word in ``items``, in order.

        Unknown words are skipped; repeated words yield repeated indices.
        """
        lookup = self.lookup
        # Use a membership test rather than truthiness: row index 0 is a
        # valid hit and must not be filtered out.
        return [lookup[word] for word in items if word in lookup]

    def __getitem__(self, items):
        """Fetch the vectors for one word or an iterable of words.

        Args:
            items: A single ``str``/``bytes`` word, or an iterable of words.

        Returns:
            A 2-D array with one row per *known* word, in the order the
            words were given (duplicates included). Words absent from the
            lookup are silently skipped, so the result may have fewer rows
            than ``items`` has entries, or none at all.
        """
        if isinstance(items, (str, bytes)):
            items = [items]
        indices = self._lookup_indices(items)
        with h5py.File(self.h5file, 'r') as f:
            dataset = f['word2vec']
            if not indices:
                # An empty slice yields a zero-row array that keeps the
                # vector dimension, avoiding fancy indexing with an empty
                # selection.
                return dataset[0:0]
            # The rows are read with a sorted, duplicate-free index list,
            # then the caller's order and duplicates are restored in memory.
            unique_sorted = sorted(set(indices))
            vectors = dataset[unique_sorted]
        position = {row: pos for pos, row in enumerate(unique_sorted)}
        return vectors[[position[row] for row in indices], ...]

    @staticmethod
    def create_db(word2vec_bin, dbpath):
        """Build a database directory from a word2vec binary file.

        Args:
            word2vec_bin: Path to a word2vec file in binary format.
            dbpath: Output directory; created if it does not exist.

        Raises:
            ImportError: If gensim is not installed.
        """
        if gensim is None:
            # Fail with a clear error instead of printing and then crashing
            # on ``None.models``.
            raise ImportError(
                "Cannot create h5db from word2vec binary file "
                "without gensim installed"
            )
        model = gensim.models.KeyedVectors.load_word2vec_format(
            word2vec_bin,
            binary=True,
        )
        os.makedirs(dbpath, exist_ok=True)
        h5file = os.path.join(dbpath, "db.h5py")
        lookupfile = os.path.join(dbpath, "lookup.pkl")

        # gensim >= 4 exposes ``key_to_index``; older releases expose
        # ``vocab`` whose entries carry an ``index`` attribute.
        key_to_index = getattr(model, "key_to_index", None)
        if key_to_index is not None:
            lookup = dict(key_to_index)
        else:
            lookup = {w: entry.index for w, entry in model.vocab.items()}

        # ``vectors`` is the current name of the embedding matrix; ``syn0``
        # is the legacy name used by older gensim releases.
        vectors = getattr(model, "vectors", None)
        if vectors is None:
            vectors = model.syn0

        with open(lookupfile, 'wb') as fd:
            pickle.dump(lookup, fd)
        with h5py.File(h5file, 'w') as f:
            f.create_dataset("word2vec", data=vectors)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment