semantle solver for https://semantle.novalis.org/
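Semantle reports, for each guess, the cosine similarity between the guess's word2vec vector and the hidden word's vector, scaled by 100; the solver intersects the sets of words consistent with each reported value. A minimal illustration of the measure (the vectors below are made up; the real dataset uses 300 dimensions):

import numpy

# cosine similarity between two word vectors: dot(u, v) / (|u| * |v|)
u = numpy.array([0.1, -0.3, 0.5])  # hypothetical 3-dimensional vectors;
v = numpy.array([0.2, -0.1, 0.4])  # the real dataset uses 300 dimensions
sim = numpy.dot(u, v) / (numpy.linalg.norm(u) * numpy.linalg.norm(v))
print('semantle would display %.2f' % (sim * 100))

The first script trims the full GoogleNews dataset down to lowercase alphabetic words, so the solver only has to load a much smaller file.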
# Trim the word2vec dataset to lowercase alphabetic words only.
import tqdm

dataset_file = 'GoogleNews-vectors-negative300.bin'
output_file = 'dataset-trimmed.bin'
word_filter = lambda z: z.isalpha() and z.lower() == z

outf = open(output_file, 'wb')

# open dataset file and read the first line (word count and dimensions)
f = open(dataset_file, 'rb')
line = ''
while True:
    c = f.read(1).decode()
    assert len(c) == 1
    line += c
    if c == '\n':
        break
numwords, numdims = map(int, line.split())  # first line

def get_word():  # read a space-terminated word from the dataset file
    word = b''
    while True:
        c = f.read(1)
        assert len(c) == 1
        word += c
        if c == b' ':
            return word

word2vec = dict()  # map word to its raw vector bytes
print('Loading dataset...')
for _ in tqdm.tqdm(range(numwords)):
    word = get_word()
    floatbin = f.read(numdims * 4)  # numdims little-endian 32-bit floats
    if word_filter(word.decode()[:-1]):  # strip the trailing space before filtering
        word2vec[word] = floatbin
assert f.read(1) == b''  # expect end of file
f.close()

outf.write(('%d %d\n' % (len(word2vec), numdims)).encode())
print('Writing trimmed dataset...')
for word in tqdm.tqdm(word2vec):
    outf.write(word)
    outf.write(word2vec[word])
outf.close()
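The solver below expects the trimmed file to be bzip2-compressed, while this script writes a plain dataset-trimmed.bin; a minimal sketch of the compression step, assuming the trimmer has already run:

import bz2

# compress dataset-trimmed.bin to the dataset-trimmed.bin.bz2 the solver reads
with open('dataset-trimmed.bin', 'rb') as src, \
     bz2.open('dataset-trimmed.bin.bz2', 'wb') as dst:
    dst.write(src.read())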
<!-- index.html: guess-entry form served by the solver's web server -->
<!DOCTYPE html>
<html>
<head>
<title>semantle solver</title>
</head>
<body>
<h1>semantle solver</h1>
<table>
<tr>
<th>word</th>
<th>similarity</th>
</tr>
<tr>
<td><input type="text" id="word0" /></td>
<td><input type="text" id="sim0" /></td>
</tr>
<tr>
<td><input type="text" id="word1" /></td>
<td><input type="text" id="sim1" /></td>
</tr>
<tr>
<td><input type="text" id="word2" /></td>
<td><input type="text" id="sim2" /></td>
</tr>
<tr>
<td><input type="text" id="word3" /></td>
<td><input type="text" id="sim3" /></td>
</tr>
<tr>
<td><input type="text" id="word4" /></td>
<td><input type="text" id="sim4" /></td>
</tr>
<tr>
<td><input type="text" id="word5" /></td>
<td><input type="text" id="sim5" /></td>
</tr>
<tr>
<td><input type="text" id="word6" /></td>
<td><input type="text" id="sim6" /></td>
</tr>
<tr>
<td><input type="text" id="word7" /></td>
<td><input type="text" id="sim7" /></td>
</tr>
<tr>
<td><input type="text" id="word8" /></td>
<td><input type="text" id="sim8" /></td>
</tr>
<tr>
<td><input type="text" id="word9" /></td>
<td><input type="text" id="sim9" /></td>
</tr>
</table>
<button onclick="submit2()">submit</button>
result: <span id="result"></span>
<script>
// collect the 10 (word, similarity) pairs and POST them as JSON
function submit2() {
    let data = [];
    for (let i = 0; i < 10; ++i) {
        data.push([document.getElementById("word" + i).value,
                   document.getElementById("sim" + i).value]);
    }
    fetch("/post/data/here", {
        method: "POST",
        headers: {'Content-Type': 'application/json'},
        body: JSON.stringify(data)
    }).then(async response => {
        let text = await response.text();
        document.getElementById("result").innerHTML = text;
        console.log(text);
    });
}
</script>
</body>
</html>
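The page posts a JSON array of [word, similarity] pairs, and the server-side handler ignores the request path, so the solver can also be driven without the browser. A sketch using only the standard library, with made-up guesses, assuming the server is running on localhost:8080:

import json
import urllib.request

# made-up example guesses: [word, similarity as shown on semantle]
data = [["ocean", 21.35], ["happy", 9.81], ["music", 14.02]]
req = urllib.request.Request(
    'http://localhost:8080/post/data/here',
    data=json.dumps(data).encode(),
    headers={'Content-Type': 'application/json'},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())  # candidate words, or '> 100 words'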
# semantle_solver.py: load the trimmed dataset and find words whose cosine
# similarity to each guess matches the values semantle reported.
import struct
import numpy
import tqdm
import bz2

#dataset_file = 'GoogleNews-vectors-negative300.bin'
dataset_file = 'dataset-trimmed.bin.bz2'
word_filter = lambda z: z.isalpha() and z.lower() == z

# open dataset file and read the first line (word count and dimensions)
f = bz2.BZ2File(dataset_file)
#f = open(dataset_file,'rb')
line = ''
while True:
    c = f.read(1).decode()
    assert len(c) == 1
    line += c
    if c == '\n':
        break
numwords, numdims = map(int, line.split())  # first line

def get_word():  # read a space-terminated word from the dataset file
    word = b''
    while True:
        c = f.read(1)
        assert len(c) == 1
        word += c
        if c == b' ':
            return word[:-1].decode()

word2vec = dict()  # map word to numpy vector
print('Loading dataset...')
for _ in tqdm.tqdm(range(numwords)):
    word = get_word()
    floatbin = f.read(numdims * 4)  # numdims little-endian 32-bit floats
    if word_filter(word):
        floats = [struct.unpack('<f', floatbin[4*i:4*i+4])[0]
                  for i in range(numdims)]
        word2vec[word] = numpy.array(floats, dtype=float)
assert f.read(1) == b''  # expect end of file
f.close()

# cache vector norms
word2norm = dict()
print('Caching vector norms...')
for word in tqdm.tqdm(word2vec):
    word2norm[word] = numpy.linalg.norm(word2vec[word])

# find words in a given similarity range of word, limit = -1 for no limit
def near_words(words, word, sim_lo=0.5, sim_hi=1.0, limit=100, progress=True):
    assert -1.0 <= sim_lo <= sim_hi <= 1.0
    result = []
    if progress:
        print('Computing near words...')
        iter_obj = tqdm.tqdm(words)
    else:
        iter_obj = words
    for w in iter_obj:
        dot = numpy.dot(word2vec[word], word2vec[w])
        sim = dot / (word2norm[word] * word2norm[w])
        if sim_lo <= sim <= sim_hi:
            result.append((w, sim))
            if len(result) == limit:
                break
    return result

# print near words, most similar first
def show_near_words(word, sim_lo=0.5, sim_hi=1.0, limit=100):
    words = near_words(word2vec.keys(), word, sim_lo, sim_hi, limit)
    words = sorted(words, key=lambda x: -x[1])
    for w in words:
        print('%10.6f %s' % (w[1], w[0]))

# compute intersection with multiple word data
# input list of (word,sim) pairs, use sim as displayed on semantle
def solve_semantle(data):
    words = set(word2vec.keys())
    is_first = True
    for word, sim in data:
        sim = sim / 100.0  # semantle displays cosine similarity * 100
        if is_first:
            print('Finding words for first intersection...')
        near_data = near_words(words, word, sim - 0.001, sim + 0.001, -1, is_first)
        is_first = False
        words &= set(n[0] for n in near_data)
    return words
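A sketch of using the module directly from a Python shell; note that importing it loads and unpacks the whole dataset, which takes a while, and the guesses and similarity values below are made up:

import semantle_solver as ss

# similarities as shown on the semantle page (cosine similarity * 100)
candidates = ss.solve_semantle([('ocean', 21.35), ('happy', 9.81)])
print(candidates)  # set of words consistent with all guesses

# or explore the neighborhood of a single word
ss.show_near_words('ocean', sim_lo=0.6, limit=20)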
# Web server front-end: serves index.html and runs the solver on POSTed guesses.
from http.server import BaseHTTPRequestHandler, HTTPServer
import logging
import json
import semantle_solver as ss

page = open('index.html', 'rb').read()

class S(BaseHTTPRequestHandler):
    def _set_response(self):
        self.send_response(200)
        self.send_header('Content-type', 'text/html')
        self.end_headers()

    def do_GET(self):
        #logging.info("GET request,\nPath: %s\nHeaders:\n%s\n", str(self.path), str(self.headers))
        self._set_response()
        self.wfile.write(page)

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])  # size of the posted data
        post_data = self.rfile.read(content_length)  # the data itself
        try:
            jsondata = json.loads(post_data)
            words = ss.solve_semantle([[j[0], float(j[1])] for j in jsondata if j[0] != ""])
            if len(words) > 100:
                response = '> 100 words'
            else:
                response = str(words)
        except Exception as e:
            response = 'error: %s' % str(e)
        #logging.info("POST request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
        #             str(self.path), str(self.headers), post_data.decode('utf-8'))
        self._set_response()
        self.wfile.write(response.encode())

def run(server_class=HTTPServer, handler_class=S, port=8080):
    logging.basicConfig(level=logging.INFO)
    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    logging.info('Starting httpd...\n')
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    httpd.server_close()
    logging.info('Stopping httpd...\n')

if __name__ == '__main__':
    from sys import argv
    if len(argv) == 2:
        run(port=int(argv[1]))
    else:
        run()
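The script can be run directly with an optional port argument (defaulting to 8080), or driven from Python; the module name `server` below is hypothetical, standing in for whatever this file is saved as:

# 'server' is a hypothetical module name for the file above
import server
server.run(port=8000)  # then open http://localhost:8000 in a browser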
The server requires the dataset file (about 66 MB), which I have uploaded to:
https://tkoz.ml/static/gist/6707cc2dde34d64797e8a9a6c0f6d135/dataset-trimmed.bin.bz2
I may reorganize this in the future, in which case the link will become outdated.
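For reference, a sketch of fetching the dataset with the standard library (the URL is the one above and may go stale):

import urllib.request

url = ('https://tkoz.ml/static/gist/6707cc2dde34d64797e8a9a6c0f6d135/'
       'dataset-trimmed.bin.bz2')
urllib.request.urlretrieve(url, 'dataset-trimmed.bin.bz2')  # ~66 MB download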