Compress a fastText model by applying a denser ngram hash
# based on Andrey Vasnetsov's code: https://gist.github.com/generall/68fddb87ae1845d6f54c958ed3d0addb
import numpy as np
import gensim
from collections import defaultdict
from copy import deepcopy
from gensim.models.utils_any2vec import ft_ngram_hashes
from tqdm.auto import tqdm


def compress_fasttext(ft, new_vocab_size=1000, new_ngrams_size=20_000):
    # make sure the model is consistent:
    # ft.vectors[w] is the average of ft.vectors_vocab[w] and the vectors of all ngrams of w
    ft.adjust_vectors()

    # select the vocabulary to keep: the new_vocab_size most frequent words
    sorted_vocab = sorted(ft.vocab.items(), key=lambda x: x[1].count, reverse=True)
    top_vocab = dict(deepcopy(sorted_vocab[:new_vocab_size]))
    top_vector_ids = []
    for new_index, vocab_item in enumerate(top_vocab.values()):
        top_vector_ids.append(vocab_item.index)  # remember the old row id before overwriting it
        vocab_item.index = new_index
    top_vectors_vocab = ft.vectors_vocab[top_vector_ids, :]  # only needed if the model will be fine-tuned further
    top_vectors = ft.vectors[top_vector_ids, :]

    # remap the ngrams: record which old buckets collapse into each new (smaller) bucket
    new_to_old_buckets = defaultdict(set)
    old_hash_count = defaultdict(int)
    for word, vocab_word in tqdm(ft.vocab.items()):
        old_hashes = ft_ngram_hashes(word, ft.min_n, ft.max_n, ft.bucket, fb_compatible=ft.compatible_hash)
        new_hashes = ft_ngram_hashes(word, ft.min_n, ft.max_n, new_ngrams_size, fb_compatible=ft.compatible_hash)
        for old_hash in old_hashes:
            old_hash_count[old_hash] += 1  # ngram frequency, used for weighting below
        for old_hash, new_hash in zip(old_hashes, new_hashes):
            new_to_old_buckets[new_hash].add(old_hash)

    # create the new, smaller FastText model instance
    new_ft = gensim.models.keyedvectors.FastTextKeyedVectors(
        vector_size=ft.vector_size,
        min_n=ft.min_n,
        max_n=ft.max_n,
        bucket=new_ngrams_size,
        compatible_hash=ft.compatible_hash,
    )
    new_ft.init_ngrams_weights(42)  # default random seed

    # set the shrunken vocab and its vectors
    new_ft.vectors_vocab = None  # if we don't fine-tune the model, we don't need these vectors
    new_ft.vectors = top_vectors
    new_ft.vocab = top_vocab

    # set the ngram vectors: each new bucket is the frequency-weighted average of the old buckets mapped to it
    for new_hash, old_buckets in tqdm(new_to_old_buckets.items()):
        total_sum = sum(old_hash_count[old_hash] for old_hash in old_buckets)
        new_vector = np.zeros(ft.vector_size, dtype=np.float32)
        for old_hash in old_buckets:
            weight = old_hash_count[old_hash] / total_sum
            new_vector += ft.vectors_ngrams[old_hash] * weight
        new_ft.vectors_ngrams[new_hash] = new_vector

    return new_ft
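
For reference, a minimal usage sketch, not part of the gist itself: it assumes gensim 3.x (where `load_facebook_model` and `ft_ngram_hashes` are available) and a Facebook-format .bin model at a hypothetical path `cc.en.300.bin`; the path and the compression parameters are placeholders to adjust for your setup.

import gensim

# load a full fastText model in Facebook binary format (gensim 3.8+)
big_model = gensim.models.fasttext.load_facebook_model('cc.en.300.bin').wv

# keep the 1000 most frequent words and hash ngrams into 20_000 buckets
small_model = compress_fasttext(big_model, new_vocab_size=1000, new_ngrams_size=20_000)

# out-of-vocabulary words still get vectors, now via the remapped ngram buckets
vector = small_model['misspeled_word']
print(vector.shape)  # (300,) for a 300-dimensional model

Because every new ngram bucket stores the frequency-weighted average of the old buckets that collapse into it, the compressed model degrades gracefully on rare and out-of-vocabulary words while using far fewer ngram rows.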