Skip to content

Instantly share code, notes, and snippets.

@mynameisfiber
Created January 9, 2018 16:05
Show Gist options
  • Save mynameisfiber/960ccae07daa2d891df9f88bfd7e3fbe to your computer and use it in GitHub Desktop.
Save mynameisfiber/960ccae07daa2d891df9f88bfd7e3fbe to your computer and use it in GitHub Desktop.
h5py file cache for word2vec model
import h5py
import os
import pickle
# gensim is optional: it is only needed by Word2VecLookup.create_db to read
# the original word2vec binary. Reading an existing database works without it.
try:
    import gensim
except ImportError:
    gensim = None
class Word2VecLookup(object):
    """Disk-backed word -> vector lookup built from a word2vec binary model.

    A database directory (written by :meth:`create_db`) contains:

    * ``db.h5py``    -- HDF5 file whose ``word2vec`` dataset holds the model's
      vector matrix, one row per vocabulary word;
    * ``lookup.pkl`` -- pickled ``dict`` mapping each word to its row index
      in that dataset.

    Only the word -> row dict is kept in memory; vectors are read from the
    HDF5 file on every lookup.
    """

    # File/dataset names inside the database directory, shared by the reader
    # (__init__/__getitem__) and the writer (create_db) so they cannot drift.
    _H5_NAME = "db.h5py"
    _LOOKUP_NAME = "lookup.pkl"
    _DATASET = "word2vec"

    def __init__(self, dbpath):
        """Open the database directory *dbpath* produced by :meth:`create_db`.

        Raises:
            TypeError: if either database file is missing (TypeError is kept
                for compatibility with existing callers).
        """
        self.h5file = os.path.join(dbpath, self._H5_NAME)
        self.lookupfile = os.path.join(dbpath, self._LOOKUP_NAME)
        if not (os.path.exists(self.h5file)
                and os.path.exists(self.lookupfile)):
            raise TypeError(
                "Word2VecLookup directory %r is malformed. Please recreate "
                "using Word2VecLookup.create_db" % (dbpath,)
            )
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only open database directories you created or otherwise trust.
        with open(self.lookupfile, 'rb') as fd:
            self.lookup = pickle.load(fd)

    def __getitem__(self, items):
        """Return the vectors for a word or an iterable of words.

        A single ``str``/``bytes`` key is treated as a one-element list.
        Words missing from the lookup table are skipped silently, so the
        result has one row per *known* word, in input order (repeated words
        yield repeated rows). If no word is known, a zero-row array is
        returned.
        """
        if isinstance(items, (str, bytes)):
            items = [items]
        get = self.lookup.get
        # Compare against None explicitly: row index 0 is a valid entry and
        # would be dropped by a truthiness filter.
        w2v_indices = [idx for idx in map(get, items) if idx is not None]
        # h5py fancy indexing requires strictly increasing, unique indices,
        # so read the sorted unique rows and re-expand to input order below.
        w2v_indices_sort = sorted(set(w2v_indices))
        with h5py.File(self.h5file, 'r') as f:
            dset = f[self._DATASET]
            if not w2v_indices_sort:
                # Zero-length slice yields an empty array with the right
                # trailing shape, avoiding an empty fancy-index selection.
                return dset[0:0]
            vectors = dset[w2v_indices_sort]
        position = {idx: pos for pos, idx in enumerate(w2v_indices_sort)}
        return vectors[[position[idx] for idx in w2v_indices], ...]

    @staticmethod
    def create_db(word2vec_bin, dbpath):
        """Build a lookup database in *dbpath* from a binary word2vec file.

        Creates *dbpath* if needed and overwrites any existing database files.

        Raises:
            ImportError: if gensim is not installed.
        """
        if gensim is None:
            raise ImportError("Cannot create h5db from word2vec binary file "
                              "without gensim installed")
        model = gensim.models.KeyedVectors.load_word2vec_format(
            word2vec_bin,
            binary=True
        )
        os.makedirs(dbpath, exist_ok=True)
        h5file = os.path.join(dbpath, Word2VecLookup._H5_NAME)
        lookupfile = os.path.join(dbpath, Word2VecLookup._LOOKUP_NAME)
        if hasattr(model, "key_to_index"):
            # gensim >= 4.0: `vocab` and `syn0` were removed.
            lookup = dict(model.key_to_index)
            vectors = model.vectors
        else:
            # gensim < 4.0 API.
            lookup = {w: d.index for w, d in model.vocab.items()}
            vectors = model.syn0
        # Write the vectors first: if that fails, no lookup.pkl is written,
        # and __init__'s existence check rejects the incomplete directory.
        with h5py.File(h5file, 'w') as f:
            f.create_dataset(Word2VecLookup._DATASET, data=vectors)
        with open(lookupfile, 'wb') as fd:
            pickle.dump(lookup, fd)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment