Shrinking Fasttext embeddings
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vasnetsov/project/EntityCategoryPrediction/venv/lib/python3.6/site-packages/smart_open/ssh.py:34: UserWarning: paramiko missing, opening SSH/SCP/SFTP paths will be disabled. `pip install paramiko` to suppress\n",
      " warnings.warn('paramiko missing, opening SSH/SCP/SFTP paths will be disabled. `pip install paramiko` to suppress')\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import tqdm\n",
    "import numpy as np\n",
    "import gensim\n",
    "\n",
    "from collections import defaultdict\n",
    "from gensim.models.utils_any2vec import ft_ngram_hashes  # Hashes n-grams to determine their position in the n-gram matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert Fasttext model into Gensim format\n",
    "\n",
    "You may skip this step if the model is already in Gensim format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft = gensim.models.FastText.load_fasttext_format(\"../data/fasttext_embedding.bin\")  # Original fasttext embeddings from https://fasttext.cc/\n",
    "ft.wv.save('../data/fasttext_gensim.model')  # Training weights are dropped to save space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading fasttext model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft = gensim.models.KeyedVectors.load(\"../data/fasttext_gensim.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the new sizes of the embedding matrices\n",
    "new_vocab_size = 250_000\n",
    "new_ngrams_size = 1_000_000  # Must evenly divide the original bucket count, so each old bucket maps to exactly one new bucket"
   ]
  },
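  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check: the new bucket count should evenly divide the original one (`ft.bucket`, the same attribute used in the remapping below), so that n-grams sharing an old bucket still share a new one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each old bucket folds into exactly one new bucket only if this holds\n",
    "assert ft.bucket % new_ngrams_size == 0, \"new_ngrams_size must evenly divide the original bucket count\""
   ]
  },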
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### New vocab\n",
    "Here we select the most frequent words from the existing vocabulary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_vocab = sorted(ft.vocab.items(), key=lambda x: x[1].count, reverse=True)\n",
    "top_vocab = dict(sorted_vocab[:new_vocab_size])\n",
    "\n",
    "top_vector_ids = [x.index for x in top_vocab.values()]\n",
    "assert max(top_vector_ids) < new_vocab_size  # Assumes the vocabulary is already sorted by frequency\n",
    "\n",
    "top_vocab_vectors = ft.vectors_vocab[:new_vocab_size]"
   ]
  },
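  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, estimate how much of the corpus the retained words cover, using the same per-word counts the sort above relies on. A sketch; a high fraction here suggests the vocabulary cut is cheap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kept_count = sum(v.count for v in top_vocab.values())\n",
    "total_count = sum(v.count for v in ft.vocab.values())\n",
    "print(f\"Token coverage of retained vocabulary: {kept_count / total_count:.2%}\")"
   ]
  },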
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ngrams remapping\n",
    "\n",
    "An n-gram's vector is located by hashing the n-gram's characters.\n",
    "We need to compute new hashes for the smaller matrix and map the old vectors to new ones.\n",
    "Since the new matrix is smaller, there will be collisions.\n",
    "We resolve them by taking a weighted sum of the vectors of the collided n-grams."
   ]
  },
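  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols (notation introduced here for clarity): if old buckets with n-gram frequencies $c_1, \\dots, c_k$ and vectors $v_1, \\dots, v_k$ collide into a single new bucket, the new vector is their frequency-weighted mean:\n",
    "\n",
    "$$v_{new} = \\sum_{i=1}^{k} \\frac{c_i}{\\sum_{j=1}^{k} c_j} \\, v_i$$"
   ]
  },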
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2519370/2519370 [02:12<00:00, 19027.06it/s]\n"
     ]
    }
   ],
   "source": [
    "new_to_old_buckets = defaultdict(set)\n",
    "old_hash_count = defaultdict(int)\n",
    "for word, vocab_word in tqdm.tqdm(ft.vocab.items()):\n",
    "    old_hashes = ft_ngram_hashes(word, ft.min_n, ft.max_n, ft.bucket, fb_compatible=ft.compatible_hash)\n",
    "    new_hashes = ft_ngram_hashes(word, ft.min_n, ft.max_n, new_ngrams_size, fb_compatible=ft.compatible_hash)\n",
    "\n",
    "    for old_hash in old_hashes:\n",
    "        old_hash_count[old_hash] += 1  # Count n-gram frequencies for proper weighting\n",
    "\n",
    "    for old_hash, new_hash in zip(old_hashes, new_hashes):\n",
    "        new_to_old_buckets[new_hash].add(old_hash)"
   ]
  },
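  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, inspect the collision load of the mapping built above: how many old buckets fold into each new bucket, on average and at worst."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bucket_loads = [len(old_buckets) for old_buckets in new_to_old_buckets.values()]\n",
    "print(\"mean old buckets per new bucket:\", np.mean(bucket_loads))\n",
    "print(\"max old buckets per new bucket:\", max(bucket_loads))"
   ]
  },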
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new FastText model instance\n",
    "new_ft = gensim.models.keyedvectors.FastTextKeyedVectors(\n",
    "    vector_size=ft.vector_size,\n",
    "    min_n=ft.min_n,\n",
    "    max_n=ft.max_n,\n",
    "    bucket=new_ngrams_size,\n",
    "    compatible_hash=ft.compatible_hash\n",
    ")\n",
    "\n",
    "# Set the shrunk vocab and vocab vectors\n",
    "new_ft.vectors_vocab = top_vocab_vectors\n",
    "new_ft.vectors = new_ft.vectors_vocab\n",
    "new_ft.vocab = top_vocab\n",
    "\n",
    "# Set n-gram vectors\n",
    "new_ft.init_ngrams_weights(42)  # Default random seed; buckets hit by the remapping below are overwritten\n",
    "for new_hash, old_buckets in new_to_old_buckets.items():\n",
    "    total_sum = sum(old_hash_count[old_hash] for old_hash in old_buckets)\n",
    "\n",
    "    new_vector = np.zeros(ft.vector_size, dtype=np.float32)\n",
    "    for old_hash in old_buckets:\n",
    "        weight = old_hash_count[old_hash] / total_sum\n",
    "        new_vector += ft.vectors_ngrams[old_hash] * weight\n",
    "\n",
    "    new_ft.vectors_ngrams[new_hash] = new_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_ft.save('../data/shrinked_fasttext.model')"
   ]
  },
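  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal usage sketch: the shrunk model loads back with the same `KeyedVectors.load` call used for the full model above, and out-of-vocabulary words are still composed from n-gram vectors (the example words are placeholders)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "small_ft = gensim.models.KeyedVectors.load('../data/shrinked_fasttext.model')\n",
    "vector = small_ft.get_vector('word')  # in-vocabulary lookup\n",
    "oov_vector = small_ft.get_vector('untypicalword')  # out-of-vocabulary: composed from n-gram vectors"
   ]
  },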
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculating losses\n",
    "\n",
    "We can estimate the accuracy loss by measuring the similarity between the original vocabulary vectors and the new vectors reconstructed from the shrunk n-gram matrix.\n",
    "The test words are taken from just outside the retained vocabulary, so their new vectors are built purely from n-grams."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_vocab_size = 100_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_vocab = sorted_vocab[new_vocab_size:new_vocab_size + test_vocab_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "sims = []\n",
    "for test_word, _ in test_vocab:\n",
    "    sim = ft.cosine_similarities(ft.get_vector(test_word), [new_ft.get_vector(test_word)])\n",
    "    if not np.isnan(sim):\n",
    "        sims.append(sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size: 1000000 Similarity: 0.9471525\n"
     ]
    }
   ],
   "source": [
    "# Note: dividing by test_vocab_size rather than len(sims) counts NaN-filtered words as zero similarity\n",
    "print(f\"size: {new_ngrams_size}\", \"Similarity:\", np.sum(sims) / test_vocab_size)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |