Last active
March 1, 2023 09:37
-
-
Save soaxelbrooke/f75ff2a6f8a432afaf72990bcfe6f086 to your computer and use it in GitHub Desktop.
Script for converting txt word embedding files to SQLite databases for fast embedding lookup.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3.6
"""
Convert a whitespace-separated word-embedding text file (GloVe or fastText
format) into an SQLite database for fast per-token vector lookup.

Example usage:
    $ python3.6 wvsqlite.py glove.840B.300d.txt

Produces an SQLite database at ./vectors.sqlite that stores, for each word
vector, the raw bytes of its float components, indexed by token. This makes
lookup fast when the working vocabulary is much smaller than the embedding
vocabulary (as most real vocabularies are).

The float size can be set via the FLOAT_BYTES env var (4 or 8, default 4), and
LIMIT can be set to keep only the first N word vectors in the file.
Metadata is also saved in the `vector_meta` table.
"""
import pandas | |
import csv | |
import sqlite3 | |
from tqdm import tqdm | |
import sys | |
import os | |
def guess_embed_dim(embeddings_path: str) -> int:
    """Return the number of whitespace-separated fields on a data line.

    Inspects at most the first two lines of the file and returns the larger
    field count. Taking the max over two lines means a fastText-style header
    (``<vocab_size> <dim>``, only two fields) is outvoted by the first data
    line. For a data line ``token v1 ... vD`` the result is ``D + 1``: the
    token column plus the vector components.

    Lines are split on runs of whitespace, so a trailing space before the
    newline is not counted as an extra field.

    :param embeddings_path: path to the embedding text file.
    :returns: the field count of a data line (token included).
    :raises ValueError: if the file is empty.
    """
    with open(embeddings_path, encoding='utf-8') as infile:
        # readline() returns '' at EOF, so short files are handled gracefully.
        first_lines = [infile.readline(), infile.readline()]
    field_counts = [len(line.split()) for line in first_lines if line]
    if not field_counts:
        raise ValueError(
            f'cannot guess embedding dimension: {embeddings_path!r} is empty')
    return max(field_counts)
def load_wvs(embeddings_path: str, embedding_dim: int, limit=None):
    """Load a whitespace-separated word-vector text file into a DataFrame.

    The returned frame is indexed by token (field 0 of each line). The
    remaining ``embedding_dim - 1`` columns hold the vector components, so
    ``embedding_dim`` is the total number of fields per line, token included.

    A fastText-style header line (exactly two integers: ``<vocab_size> <dim>``)
    is skipped. Files without one (e.g. GloVe) are read from the first line.

    :param embeddings_path: path to the embedding text file.
    :param embedding_dim: number of fields per data line, token included.
    :param limit: if not None, anything ``int()`` accepts; at most this many
        vectors are read.
    :returns: a ``pandas.DataFrame`` indexed by token.
    """
    if limit is not None:
        limit = int(limit)
    with open(embeddings_path, encoding='utf-8') as infile:
        first_fields = infile.readline().split()
        is_header = (len(first_fields) == 2
                     and all(field.isdigit() for field in first_fields))
        if not is_header:
            # No header (GloVe): rewind so the first vector is not dropped.
            infile.seek(0)
        return pandas.read_csv(
            infile,
            header=None,
            sep=r'\s+',  # same as the deprecated delim_whitespace=True
            names=list(range(embedding_dim)),
            # Tokens may be bare quote characters; never treat them as quoting.
            quoting=csv.QUOTE_NONE,
            # Vocabularies contain tokens like "null", "NA" and "nan"; keep
            # them as strings instead of letting pandas turn them into NaN.
            keep_default_na=False,
            nrows=limit,
            index_col=0,
        )
def insert_wvs(wvs: pandas.DataFrame, embedding_dim: int, float_bytes: int,
               db_path: str = 'vectors.sqlite'):
    """Write word vectors into a fresh SQLite database at ``db_path``.

    Any existing file at ``db_path`` is removed first. Two tables are created:

    * ``vectors(token text primary key, vector_bytes blob)``: one row per
      distinct token. Only the first occurrence of a duplicated token is
      kept. ``vector_bytes`` is the row's components cast to float32
      (``float_bytes == 4``) or float64 (``8``) and serialized with
      ``ndarray.tobytes()`` (machine-native byte order).
    * ``vector_meta(vector_float_bytes, embedding_dimensions, vocab_size)``:
      a single row describing the stored vectors.

    :param wvs: DataFrame indexed by token, one column per vector component.
    :param embedding_dim: kept for backward compatibility only. The stored
        ``embedding_dimensions`` is ``wvs.shape[1]``, because the script
        passes ``guess_embed_dim``'s per-line field count, which also counts
        the token column.
    :param float_bytes: bytes per stored float; must be 4 or 8.
    :param db_path: output database path (default ``vectors.sqlite``).
    :raises ValueError: if ``float_bytes`` is not 4 or 8.
    """
    if float_bytes not in (4, 8):
        raise ValueError(f'float_bytes must be 4 or 8, got {float_bytes!r}')
    float_type = 'float32' if float_bytes == 4 else 'float64'

    try:
        os.remove(db_path)
    except FileNotFoundError:
        pass  # nothing to replace; other OS errors should surface

    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE vector_meta (
                vector_float_bytes integer,
                embedding_dimensions integer,
                vocab_size integer
            )
        ''')
        conn.execute('CREATE TABLE vectors (token text primary key, vector_bytes blob);')

        seen = set()
        for token, series in tqdm(wvs.iterrows(), total=wvs.shape[0]):
            # The primary key forbids duplicates; keep the first occurrence.
            if token in seen:
                continue
            seen.add(token)
            vec_bytes = series.values.astype(float_type).tobytes()
            conn.execute('INSERT INTO vectors VALUES (?, ?)', (token, vec_bytes))

        # Written after the loop so vocab_size counts only the tokens stored.
        conn.execute('INSERT INTO vector_meta VALUES (?, ?, ?)',
                     (float_bytes, wvs.shape[1], len(seen)))
        conn.commit()
    finally:
        conn.close()
if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit(f'Usage: {sys.argv[0]} <embeddings.txt>')
    embed_path = sys.argv[1]
    float_bytes = int(os.getenv('FLOAT_BYTES', 4))
    # Validate up front instead of after loading a multi-GB embedding file.
    if float_bytes not in (4, 8):
        sys.exit(f'FLOAT_BYTES must be 4 or 8, got {float_bytes}')
    # Per-line field count (token + vector components), as load_wvs expects.
    embed_dim = guess_embed_dim(embed_path)
    print('Loading word vectors...')
    wvs = load_wvs(embed_path, embed_dim, os.getenv('LIMIT'))
    print('Saving word vector bytes to vectors.sqlite...')
    insert_wvs(wvs, embed_dim, float_bytes)
    print('Done!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment