@generall
Created April 27, 2019 22:32
Shrinking Fasttext embeddings
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/vasnetsov/project/EntityCategoryPrediction/venv/lib/python3.6/site-packages/smart_open/ssh.py:34: UserWarning: paramiko missing, opening SSH/SCP/SFTP paths will be disabled. `pip install paramiko` to suppress\n",
" warnings.warn('paramiko missing, opening SSH/SCP/SFTP paths will be disabled. `pip install paramiko` to suppress')\n"
]
}
],
"source": [
"import os\n",
"import tqdm\n",
"import numpy as np\n",
"import gensim\n",
"\n",
"from collections import defaultdict\n",
"from gensim.models.utils_any2vec import ft_ngram_hashes # This function is used to calculate hashes from ngrams to determine position in ngram matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Convert Fasttext model into Gensim format\n",
"\n",
"You may skip this step if the model is already in Gensim format."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"ft = gensim.models.FastText.load_fasttext_format(\"../data/fasttext_embedding.bin\") # Original fastText embeddings from https://fasttext.cc/\n",
"ft.wv.save('../data/fasttext_gensim.model') # Save only the word vectors, not the training weights, to save space"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading fasttext model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"ft = gensim.models.KeyedVectors.load(\"../data/fasttext_gensim.model\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Set the new sizes of the vocabulary and the n-gram matrix\n",
"new_vocab_size = 250_000\n",
"new_ngrams_size = 1_000_000 # Should be a divisor of the original bucket count (ft.bucket)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### New vocab\n",
"Here we select the most frequent words in the existing vocabulary."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"sorted_vocab = sorted(ft.vocab.items(), key=lambda x: x[1].count, reverse=True)\n",
"top_vocab = dict(sorted_vocab[:new_vocab_size])\n",
"\n",
"top_vector_ids = [x.index for x in top_vocab.values()]\n",
"assert max(top_vector_ids) < new_vocab_size # Assume vocabulary is already sorted by frequency\n",
"\n",
"top_vocab_vectors = ft.vectors_vocab[:new_vocab_size]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Ngrams remapping\n",
"\n",
"Ngram vectors are located by hashing the characters of the ngram.\n",
"We need to calculate new hashes for the smaller matrix and map the old vectors to the new ones.\n",
"Since the new matrix is smaller, there will be collisions.\n",
"We resolve them by taking a frequency-weighted sum of the vectors of the colliding ngrams."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2519370/2519370 [02:12<00:00, 19027.06it/s]\n"
]
}
],
"source": [
"new_to_old_buckets = defaultdict(set)\n",
"old_hash_count = defaultdict(int)\n",
"for word, vocab_word in tqdm.tqdm(ft.vocab.items()):\n",
"    old_hashes = ft_ngram_hashes(word, ft.min_n, ft.max_n, ft.bucket, fb_compatible=ft.compatible_hash)\n",
"    new_hashes = ft_ngram_hashes(word, ft.min_n, ft.max_n, new_ngrams_size, fb_compatible=ft.compatible_hash)\n",
"\n",
"    for old_hash in old_hashes:\n",
"        old_hash_count[old_hash] += 1 # calculate frequency of ngrams for proper weighting\n",
"\n",
"    for old_hash, new_hash in zip(old_hashes, new_hashes):\n",
"        new_to_old_buckets[new_hash].add(old_hash)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Create new FastText model instance\n",
"new_ft = gensim.models.keyedvectors.FastTextKeyedVectors(\n",
"    vector_size=ft.vector_size,\n",
"    min_n=ft.min_n,\n",
"    max_n=ft.max_n,\n",
"    bucket=new_ngrams_size,\n",
"    compatible_hash=ft.compatible_hash\n",
")\n",
"\n",
"# Set the shrunken vocab and vocab vectors\n",
"new_ft.vectors_vocab = top_vocab_vectors\n",
"new_ft.vectors = new_ft.vectors_vocab\n",
"new_ft.vocab = top_vocab\n",
"\n",
"# Set ngram vectors\n",
"new_ft.init_ngrams_weights(42) # Default random seed\n",
"for new_hash, old_buckets in new_to_old_buckets.items():\n",
"    total_sum = sum(old_hash_count[old_hash] for old_hash in old_buckets)\n",
"\n",
"    # Weight each old bucket by its share of the ngram occurrences mapped to this new bucket\n",
"    new_vector = np.zeros(ft.vector_size, dtype=np.float32)\n",
"    for old_hash in old_buckets:\n",
"        weight = old_hash_count[old_hash] / total_sum\n",
"        new_vector += ft.vectors_ngrams[old_hash] * weight\n",
"\n",
"    new_ft.vectors_ngrams[new_hash] = new_vector"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"new_ft.save('../data/shrinked_fasttext.model')"
]
},
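{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculating losses\n",
"\n",
"We can estimate the accuracy loss by measuring the cosine similarity between the original vocab vectors and the vectors that the new model rebuilds from the shrunken n-grams. The test words are taken from outside the new vocabulary, so the new model has to build them from n-gram vectors alone."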
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"test_vocab_size = 100_000"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"test_vocab = sorted_vocab[new_vocab_size:new_vocab_size + test_vocab_size]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"sims = []\n",
"for test_word, _ in test_vocab:\n",
"    sim = ft.cosine_similarities(ft.get_vector(test_word), [new_ft.get_vector(test_word)])\n",
"    if not np.isnan(sim):\n",
"        sims.append(sim)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"size: 1000000 Similarity: 0.9471525\n"
]
}
],
"source": [
"print(f\"size: {new_ngrams_size}\", \"Similarity:\", np.sum(sims) / test_vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}