Created
February 19, 2014 12:37
-
-
Save moustaki/9091097 to your computer and use it in GitHub Desktop.
Faster save-load for word2vec
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dbm, os | |
import cPickle as pickle | |
from gensim.models import Word2Vec | |
import numpy as np | |
def save_model(model, directory):
    """Persist *model* into *directory* in a layout that is quick to reload.

    Four artefacts are written inside *directory* (created if missing):

    * ``word_to_index``  -- DBM database mapping each vocabulary word to the
      pickled ``model.vocab`` entry for that word.
    * ``index_to_word``  -- DBM database mapping ``str(entry.index)`` to the
      word.
    * ``syn0norm.dat``   -- ``model.syn0norm`` written through a float32
      ``numpy.memmap``.
    * ``model.pickle``   -- the model object pickled with ``vocab``,
      ``syn0norm``, ``syn0`` and ``index2word`` temporarily set to ``None``.

    The four attributes are put back on *model* before returning, including
    when pickling raises.

    :param model: object providing ``init_sims()``, ``vocab``, ``syn0norm``,
        ``syn0`` and ``index2word`` (the file imports gensim's ``Word2Vec``).
    :param directory: target directory path.
    """
    model.init_sims()  # making sure syn0norm is initialised
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Saving indexes as DBM'ed dictionaries. try/finally (rather than `with`)
    # so the handles are always closed, even on interpreters whose dbm objects
    # are not context managers.
    word_to_index = dbm.open(os.path.join(directory, 'word_to_index'), 'n')
    try:
        index_to_word = dbm.open(os.path.join(directory, 'index_to_word'), 'n')
        try:
            for word, entry in model.vocab.items():
                word_to_index[word] = pickle.dumps(entry)
                index_to_word[str(entry.index)] = word
        finally:
            index_to_word.close()
    finally:
        word_to_index.close()

    # Memory-mapping normalised word vectors
    syn0norm_m = np.memmap(
        os.path.join(directory, 'syn0norm.dat'),
        dtype='float32',
        mode='w+',
        shape=model.syn0norm.shape,
    )
    syn0norm_m[:] = model.syn0norm[:]
    syn0norm_m.flush()
    del syn0norm_m  # drop the mapping once the data is on disk

    # And pickling the model object, without its bulky data. The attributes
    # are detached for the duration of the dump and always re-attached, so a
    # failed dump cannot leave the caller with a gutted model.
    detached = (model.vocab, model.syn0norm, model.syn0, model.index2word)
    model.vocab, model.syn0norm, model.syn0, model.index2word = None, None, None, None
    try:
        # Pickle streams are bytes: the file must be opened in binary mode.
        with open(os.path.join(directory, 'model.pickle'), 'wb') as model_f:
            pickle.dump(model, model_f)
    finally:
        model.vocab, model.syn0norm, model.syn0, model.index2word = detached
def load_model(directory):
    """Rebuild a model from a directory written by ``save_model``.

    The pickled model object is read from ``model.pickle``; its ``vocab`` and
    ``index2word`` attributes are replaced by read-only ``DBMPickledDict``
    views over the ``word_to_index`` and ``index_to_word`` databases, and
    ``syn0norm`` becomes a read-only float32 ``numpy.memmap`` over
    ``syn0norm.dat`` with shape ``(len(vocab), model.layer1_size)``.
    ``syn0`` is bound to the same memmap as ``syn0norm``.

    :param directory: directory previously populated by ``save_model``.
    :returns: the unpickled model with the attributes above attached.
    """
    # NOTE(security): unpickling can execute arbitrary code -- only load
    # directories that come from a trusted source.
    # Binary mode + context manager: pickle data is bytes, and the handle
    # must not be left for the garbage collector to close.
    with open(os.path.join(directory, 'model.pickle'), 'rb') as model_f:
        model = pickle.load(model_f)

    model.vocab = DBMPickledDict(os.path.join(directory, 'word_to_index'))
    model.index2word = DBMPickledDict(os.path.join(directory, 'index_to_word'))

    # The vector file has one row per vocabulary entry; asking the mapping
    # for its length avoids materialising the full key list.
    model.syn0norm = np.memmap(
        os.path.join(directory, 'syn0norm.dat'),
        dtype='float32',
        mode='r',
        shape=(len(model.vocab), model.layer1_size),
    )
    # Only the normalised vectors are stored on disk, so syn0 aliases them.
    model.syn0 = model.syn0norm
    return model
class DBMPickledDict(dict):
    """Read-only, dict-like view over a DBM database file.

    Lookups go to the DBM handle, not to the (always empty) underlying
    ``dict`` storage. Two key flavours are supported:

    * ``int`` keys are converted with ``str()`` and the stored value is
      returned as-is;
    * any other key is looked up unchanged and the stored value is
      unpickled before being returned.

    ``values()`` and ``itervalues()`` return the stored values as-is, i.e.
    they do not unpickle.
    """

    def __init__(self, dbm_file):
        """Open *dbm_file* read-only."""
        dict.__init__(self)
        self._dbm = dbm.open(dbm_file, 'r')

    @staticmethod
    def _db_key(key):
        """Return the key as stored in the database (ints become strings)."""
        return str(key) if isinstance(key, int) else key

    def __setitem__(self, key, value):
        # TypeError is still an Exception, so existing handlers keep working.
        raise TypeError("Read-only vocabulary")

    def __delitem__(self, key):
        raise TypeError("Read-only vocabulary")

    def __iter__(self):
        return iter(self._dbm.keys())

    def __len__(self):
        return len(self._dbm)

    def __contains__(self, key):
        return self._db_key(key) in self._dbm

    def __getitem__(self, key):
        stored = self._dbm[self._db_key(key)]
        if isinstance(key, int):
            # Integer lookups yield the stored value untouched.
            return stored
        return pickle.loads(stored)

    def get(self, key, default=None):
        """Return ``self[key]`` if present, else *default*.

        Overridden because the inherited ``dict.get`` reads the empty
        underlying dict and would never see the database contents.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def keys(self):
        return self._dbm.keys()

    def values(self):
        # Stored values are returned without unpickling.
        return [self._dbm[key] for key in self._dbm.keys()]

    def itervalues(self):
        # Lazy counterpart of values(); also returns stored values as-is.
        return (self._dbm[key] for key in self._dbm.keys())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment