Created
July 11, 2019 06:25
-
-
Save Dumbris/a573df979fb65fe9cca2f9f1396e2233 to your computer and use it in GitHub Desktop.
Universal Sentence Encoder + nmslib for kNN search. Minimal example.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| import numpy as np | |
| import pandas as pd | |
| import nmslib | |
# Corpus: twenty short beer-review snippets to index for kNN search.
docs = [
    "good beer",
    "sweet bready malt noticeable",
    "good one share friends nice statement",
    "overall nice beer delectably subtle use rum expecting bit smoke",
    "fact one bit gusher popped cap",
    "slight fruit tastes found much",
    "others style tend least make taste better",
    "score lagerpilsner aromas always 3 unless actually detect real smell beer gets normal 3",
    "average bar beer average bars",
    "sweet enjoy malty yeasty mouthfeel oily slick",
    "molasses yes",
    "fun excited see schlafly try something new",
    "taste tart berry reminiscent champagne",
    "pours clear honey color little bit haze glass yeast bottom bottle",
    "mouthfeel like inside mythical leather pants",
    "overall think might hop aversion stage right didnt enjoy much hoping",
    "puckery aftertaste",
    "roasted malt vanilla smoke",
    "damn good imperial stout",
    "leaves lace",
]
# Free-text queries to embed and look up against the corpus.
queries = [
    "vanilla",
    "fruit",
]
# Number of sentences sent through the encoder per session.run call.
batch_size = 100
# TF-Hub module URL for the Universal Sentence Encoder (large, v3).
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
# TF1-style logging: suppress info/deprecation chatter, keep only errors.
tf.logging.set_verbosity(tf.logging.ERROR)
def embed_useT(module):
    """Build the TF1 graph for a hub text-encoder module once and return a
    closure that maps a batch of strings to their embedding matrix.

    The graph, module, and MonitoredSession are created a single time here;
    every call of the returned function reuses them via session.run.
    """
    with tf.Graph().as_default():
        text_input = tf.placeholder(tf.string)
        encoder = hub.Module(module)
        encoded = encoder(text_input)
        session = tf.train.MonitoredSession()

        def run_embed(texts):
            return session.run(encoded, {text_input: texts})

        return run_embed
# Build the encoder graph/session once; embed_fn(list_of_str) -> embedding matrix.
embed_fn = embed_useT(module_url)
def get_emb(arr):
    """Embed `arr` in chunks of `batch_size` strings and stack the
    resulting embedding matrices into one array."""
    chunks = []
    for start in range(0, len(arr), batch_size):
        chunks.append(embed_fn(arr[start:start + batch_size]))
    return np.concatenate(chunks)
def build_index(X):
    """Build an nmslib small-world-graph index over the rows of X.

    Parameters:
        X: 2-D numpy array; each row is one embedding vector to index.

    Returns:
        An nmslib index (space='l2', method='sw-graph') ready for
        knnQueryBatch calls.
    """
    nmslib_index = nmslib.init(space='l2', method='sw-graph')
    # Ids are the row positions, so query hits map straight back into `docs`.
    nmslib_index.addDataPointBatch(ids=np.arange(X.shape[0], dtype=np.int32), data=X)
    nmslib_index.createIndex({}, print_progress=True)  # default build parameters
    return nmslib_index
def search(X_query, X_text_query, X_text, nmslib_index):
    """Run a batched top-3 kNN lookup and flatten the hits to result rows.

    Parameters:
        X_query: 2-D array of query embeddings, one row per query.
        X_text_query: original query strings, aligned with X_query rows.
        X_text: original corpus strings, aligned with the index ids.
        nmslib_index: an nmslib index exposing knnQueryBatch.

    Returns:
        A list of dicts with keys "query_1" (query text), "text_1"
        (matched corpus text), "query_id" (query row index), and
        "dist" (distance reported by the index).
    """
    neighbours = nmslib_index.knnQueryBatch(X_query, k=3, num_threads=2)
    results = []
    # enumerate replaces the original hand-maintained counter variable.
    for query_id, (items, dists) in enumerate(neighbours):
        for item, dist in zip(items, dists):
            results.append({
                "query_1": X_text_query[query_id],
                "text_1": X_text[int(item)],
                "query_id": query_id,
                "dist": dist,
            })
    return results
### Main
# Embed the corpus and the queries, index the corpus embeddings, then
# display the top-3 nearest corpus sentences per query as a DataFrame.
X_docs = get_emb(docs)
X_queries = get_emb(queries)
nmslib_index = build_index(X_docs)
pd.DataFrame(search(X_queries, queries, docs, nmslib_index))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment