Created
June 7, 2021 13:20
-
-
Save cuter44/8abc91dbc4d6ae1f01a23e2f43d736cd to your computer and use it in GitHub Desktop.
Embedding conversion for NNQLM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! python3 | |
""" | |
USAGE
    %0 n_voc n_dim INPUT_FILE [OUTPUT_FILE]

Convert an embedding table from text format to Google C format, which this
program needs for its input.

    n_voc, n_dim : Dimensions of the embedding table; ignored if these two
                   values are present in INPUT_FILE.
    INPUT_FILE   : Embedding table in text format. May or may not contain a
                   header row of two ints giving n_voc and n_dim, followed
                   by n_voc rows, each consisting of a word and n_dim
                   floats separated by whitespace.
    OUTPUT_FILE  : Defaults to stdout.
""" | |
import sys | |
import numpy | |
def _parse_header(line):
    """Return ``(n_voc, n_dim)`` if *line* is a row of exactly two ints, else ``None``."""
    fields = line.split()
    if len(fields) != 2:
        return None
    try:
        return int(fields[0]), int(fields[1])
    except ValueError:
        return None


def _rows(first_line, rest):
    """Yield *first_line* (unless it is ``None``) followed by every line of *rest*."""
    if first_line is not None:
        yield first_line
    yield from rest


def _convert(fin, fout, n_voc, n_dim):
    """Copy the text embedding table in *fin* to the binary stream *fout*.

    Output layout: one ASCII line ``"<n_voc> <n_dim>\\n"``, then for every
    input row the UTF-8 word, a single space, the vector as raw float32
    bytes (``numpy.ndarray.tobytes``) and a newline.

    fin   : text stream, optionally starting with a two-int header row.
    fout  : binary stream that accepts ``bytes``.
    n_voc : expected number of rows; replaced by the header value if present.
    n_dim : expected vector length; replaced by the header value if present.

    Mismatches between the expected and actual sizes are reported on stderr
    and do not stop the conversion.
    """
    first_line = fin.readline()
    header = _parse_header(first_line)
    if header is not None:
        n_voc, n_dim = header
        first_line = None  # header consumed; it is not a data row
    # Without a header the first line is a data row and is processed below,
    # so the input never has to be reopened or rewound.
    fout.write(("%d %d\n" % (n_voc, n_dim)).encode())

    cnt_line = 0
    for line in _rows(first_line, fin):
        crumbs = line.split()
        if not crumbs:
            # Blank line (or empty file): there is no word to emit.
            continue
        cnt_line += 1
        word = crumbs[0]
        v = numpy.array([float(e) for e in crumbs[1:]], dtype=numpy.float32)
        fout.write(word.encode() + b" " + v.tobytes() + b"\n")
        if v.shape[0] != n_dim:
            # Diagnostics go to stderr so they can never end up inside the
            # binary output when that output is stdout.
            print("(!) inconsistent vector dimension", cnt_line, word,
                  file=sys.stderr)

    if cnt_line != n_voc:
        print("(!) inconsistent vocabulary size", cnt_line, file=sys.stderr)


def main(argv=None):
    """Command-line entry point: ``prog n_voc n_dim INPUT_FILE [OUTPUT_FILE]``.

    argv : argument list including the program name; defaults to ``sys.argv``.

    Returns the process exit status: 0 on success, 2 when required
    arguments are missing.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 4:
        print("usage: %s n_voc n_dim INPUT_FILE [OUTPUT_FILE]"
              % (argv[0] if argv else "prog"), file=sys.stderr)
        return 2

    n_voc = int(argv[1])
    n_dim = int(argv[2])
    fn_in = argv[3]
    fn_out = argv[4] if len(argv) > 4 else None

    with open(fn_in, "r", encoding="utf-8") as fin:
        if fn_out:
            with open(fn_out, "wb") as fout:
                _convert(fin, fout, n_voc, n_dim)
        else:
            # sys.stdout is a text stream and rejects bytes; the underlying
            # binary buffer is what the converted data must be written to.
            fout = sys.stdout.buffer
            _convert(fin, fout, n_voc, n_dim)
            fout.flush()
    return 0


if __name__ == "__main__":
    sys.exit(main())
# end main
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment