import gc

import numpy as np
from nltk.stem import WordNetLemmatizer

# FILE_DIR, tokenizer, and num_words are assumed defined earlier in the notebook.

def load_embedding(embedding):
    """Load a pretrained embedding file into a {word: 300-d vector} index."""
    print(f'Loading {embedding} embedding...')
    def get_coefs(word, *arr):
        return word, np.asarray(arr, dtype='float32')
    if embedding == 'glove':
        EMBEDDING_FILE = f'{FILE_DIR}/embeddings/glove.840B.300d/glove.840B.300d.txt'
        with open(EMBEDDING_FILE, encoding='utf8') as f:
            embeddings_index = dict(get_coefs(*o.split(' ')) for o in f)
    elif embedding == 'wiki-news':
        EMBEDDING_FILE = f'{FILE_DIR}/embeddings/wiki-news-300d-1M/wiki-news-300d-1M.vec'
        with open(EMBEDDING_FILE, encoding='utf8') as f:
            # len(o) > 100 skips the header line and any malformed short lines.
            embeddings_index = dict(get_coefs(*o.split(' ')) for o in f if len(o) > 100)
    elif embedding == 'paragram':
        EMBEDDING_FILE = f'{FILE_DIR}/embeddings/paragram_300_sl999/paragram_300_sl999.txt'
        with open(EMBEDDING_FILE, encoding='utf8', errors='ignore') as f:
            embeddings_index = dict(get_coefs(*o.split(' ')) for o in f if len(o) > 100)
    elif embedding == 'google-news':
        from gensim.models import KeyedVectors
        EMBEDDING_FILE = f'{FILE_DIR}/embeddings/GoogleNews-vectors-negative300/GoogleNews-vectors-negative300.bin'
        # NOTE: this branch returns a gensim KeyedVectors object, not a dict,
        # so dict-style .get() lookups (used below) would need adjusting.
        embeddings_index = KeyedVectors.load_word2vec_format(EMBEDDING_FILE, binary=True)
    else:
        raise ValueError(f'Unknown embedding: {embedding}')
    return embeddings_index
embeddings_index_1 = load_embedding('glove')
embeddings_index_2 = load_embedding('wiki-news')
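
# Quick sanity check (an illustrative addition, not part of the original
# gist): each loaded index maps a token to a 300-d float32 vector.
vec = embeddings_index_1.get('the')
assert vec is not None and vec.shape == (300,)
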
def build_embedding_matrix(embeddings_index_1, embeddings_index_2, lower=False, upper=False):
    # NOTE: the lower/upper flags are accepted but unused in this version.
    wl = WordNetLemmatizer().lemmatize
    word_index = tokenizer.word_index
    nb_words = min(num_words, len(word_index))
    # Each row is 601-d: 300 GloVe dims + 300 wiki-news dims + 1 all-caps flag.
    embedding_matrix = np.zeros((nb_words, 601))
    # Fallback row for out-of-vocabulary words: the concatenation of both
    # vectors for the placeholder word "something", with the caps flag at 0.
    something_1 = embeddings_index_1.get("something")
    something_2 = embeddings_index_2.get("something")
    something = np.zeros((601,))
    something[:300] = something_1
    something[300:600] = something_2
    something[600] = 0
    def all_caps(word):
        return len(word) > 1 and word.isupper()

    hit, total = 0, 0

    def embed_word(embedding_matrix, i, word):
        embedding_vector_1 = embeddings_index_1.get(word)
        if embedding_vector_1 is not None:
            embedding_matrix[i, :300] = embedding_vector_1
            # The final column flags tokens written in all caps.
            embedding_matrix[i, 600] = 1 if all_caps(word) else 0
            embedding_vector_2 = embeddings_index_2.get(word)
            if embedding_vector_2 is not None:
                embedding_matrix[i, 300:600] = embedding_vector_2
    # Matching cascade: exact GloVe match, then verb/adjective lemma, then
    # upper-case form, then lemmatized upper-case form, else the fallback row.
    for word, i in word_index.items():
        if i >= num_words:
            continue
        if embeddings_index_1.get(word) is not None:
            embed_word(embedding_matrix, i, word)
            hit += 1
        else:
            if len(word) > 20:
                # Very long tokens are treated as junk and get the fallback.
                embedding_matrix[i] = something
            else:
                word2 = wl(wl(word, pos='v'), pos='a')
                if embeddings_index_1.get(word2) is not None:
                    embed_word(embedding_matrix, i, word2)
                    hit += 1
                else:
                    if len(word) < 3:
                        # Short misses keep their all-zero row and are not counted.
                        continue
                    word2 = word.upper()
                    if embeddings_index_1.get(word2) is not None:
                        embed_word(embedding_matrix, i, word2)
                        hit += 1
                    else:
                        word2 = wl(wl(word.upper(), pos='v'), pos='a')
                        if embeddings_index_1.get(word2) is not None:
                            embed_word(embedding_matrix, i, word2)
                            hit += 1
                        else:
                            embedding_matrix[i] = something
        total += 1
    print('Matched embeddings: found {} out of {} total words ({:.2f}%).'.format(
        hit, total, hit * 100.0 / total))
    return embedding_matrix
embedding_matrix = build_embedding_matrix(embeddings_index_1, embeddings_index_2, lower=True, upper=True)

# Free the large in-memory indices once the matrix has been built.
del embeddings_index_1, embeddings_index_2
gc.collect()
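
# --- Usage sketch (not part of the original gist) ---
# A minimal illustration of how a matrix like this is typically consumed,
# assuming a Keras 2.x-era model (the tokenizer.word_index usage above
# suggests keras.preprocessing.text.Tokenizer). `maxlen` is an assumed
# hyperparameter and the layer wiring is a sketch, not the author's model.
from keras.layers import Embedding, Input

maxlen = 70  # assumed maximum sequence length
inp = Input(shape=(maxlen,))
x = Embedding(embedding_matrix.shape[0], embedding_matrix.shape[1],
              weights=[embedding_matrix], trainable=False)(inp)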