Skip to content

Instantly share code, notes, and snippets.

@Dumbris
Created July 11, 2019 06:25
Show Gist options
  • Select an option

  • Save Dumbris/a573df979fb65fe9cca2f9f1396e2233 to your computer and use it in GitHub Desktop.

Select an option

Save Dumbris/a573df979fb65fe9cca2f9f1396e2233 to your computer and use it in GitHub Desktop.
Universal Sentence Encoder + nmslib for kNN search. Minimal example.
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import pandas as pd
import nmslib
# --- Demo corpus, queries, and configuration ------------------------------
# Each line of the triple-quoted literal below is one "document" (a short
# beer-review snippet, apparently pre-tokenized with stopwords removed);
# split("\n") turns it into a list of 20 strings to embed and index.
docs = \
"""good beer
sweet bready malt noticeable
good one share friends nice statement
overall nice beer delectably subtle use rum expecting bit smoke
fact one bit gusher popped cap
slight fruit tastes found much
others style tend least make taste better
score lagerpilsner aromas always 3 unless actually detect real smell beer gets normal 3
average bar beer average bars
sweet enjoy malty yeasty mouthfeel oily slick
molasses yes
fun excited see schlafly try something new
taste tart berry reminiscent champagne
pours clear honey color little bit haze glass yeast bottom bottle
mouthfeel like inside mythical leather pants
overall think might hop aversion stage right didnt enjoy much hoping
puckery aftertaste
roasted malt vanilla smoke
damn good imperial stout
leaves lace""".split("\n")
# Free-text queries to embed and run against the document index.
queries = [
"vanilla",
"fruit",
]
# Number of sentences fed to the encoder per session.run call (see get_emb).
batch_size = 100
# TF-Hub handle for the Universal Sentence Encoder, large (Transformer) v3.
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
# Silence TF INFO/WARNING chatter (tf.logging is the TF1-only logging API).
tf.logging.set_verbosity(tf.logging.ERROR)
def embed_useT(module):
    """Build a reusable sentence-embedding callable from a TF-Hub module.

    Constructs a dedicated TF1 graph containing a string placeholder fed
    into the hub module, opens a MonitoredSession over it, and returns a
    closure that embeds a list/array of strings in that session.

    Args:
        module: TF-Hub module handle or URL (e.g. the USE model URL).

    Returns:
        Callable mapping a sequence of strings to their embedding matrix
        (one row per input sentence).
    """
    graph = tf.Graph()
    with graph.as_default():
        text_input = tf.placeholder(tf.string)
        encoder = hub.Module(module)
        encoded = encoder(text_input)
        # MonitoredSession runs the module's init ops for us (TF1 API).
        sess = tf.train.MonitoredSession()

    def _embed(texts):
        return sess.run(encoded, {text_input: texts})

    return _embed
# Build the embedding closure once — graph construction and the TF-Hub
# module download are expensive; get_emb() reuses it for every batch.
embed_fn = embed_useT(module_url)
def get_emb(arr, embed=None, batch=None):
    """Embed a sequence of sentences in fixed-size batches.

    Generalized so the embedding function and batch size can be injected;
    by default it falls back to the module-level ``embed_fn`` and
    ``batch_size``, preserving the original call signature ``get_emb(arr)``.

    Args:
        arr: sequence of strings to embed.
        embed: optional callable mapping a list of strings to a 2-D
            ``np.ndarray`` of embeddings (defaults to ``embed_fn``).
        batch: optional positive int, sentences per call (defaults to
            ``batch_size``). Batching keeps per-call memory bounded.

    Returns:
        ``np.ndarray`` of shape (len(arr), dim): the row-wise concatenation
        of all per-batch embedding matrices.
    """
    if embed is None:
        embed = embed_fn
    if batch is None:
        batch = batch_size
    return np.concatenate([embed(arr[i:i + batch]) for i in range(0, len(arr), batch)])
def build_index(X):
    """Create an nmslib approximate-kNN index over the rows of X.

    Uses the SW-graph method with L2 (Euclidean) distance; each row of X
    is added with its row number as the point id, so kNN results map
    directly back to positions in the original document list.

    Args:
        X: 2-D array of embeddings, one row per document.

    Returns:
        A ready-to-query nmslib index.
    """
    index = nmslib.init(space='l2', method='sw-graph')
    row_ids = np.arange(X.shape[0], dtype=np.int32)
    index.addDataPointBatch(ids=row_ids, data=X)
    index.createIndex({}, print_progress=True)
    return index
def search(X_query, X_text_query, X_text, nmslib_index):
    """Query the index and pair each hit with its human-readable text.

    Args:
        X_query: 2-D array of query embeddings, one row per query.
        X_text_query: list of the original query strings, aligned with
            the rows of ``X_query``.
        X_text: list of the original document strings, aligned with the
            ids stored in the index.
        nmslib_index: index built by ``build_index``.

    Returns:
        List of dicts, one per (query, neighbour) pair, each with keys
        ``query_1`` (query text), ``text_1`` (matched document text),
        ``query_id`` (row index of the query) and ``dist`` (distance).
    """
    neighbours = nmslib_index.knnQueryBatch(X_query, k=3, num_threads=2)
    results = []
    # enumerate replaces the original hand-rolled counter variable.
    for query_id, (items, dists) in enumerate(neighbours):
        for item, dist in zip(items, dists):
            results.append({
                "query_1": X_text_query[query_id],
                # item arrives as a numpy integer id; int() makes it a
                # plain Python index into the text list.
                "text_1": X_text[int(item)],
                "query_id": query_id,
                "dist": dist,
            })
    return results
###Main
# Embed corpus and queries with the same encoder so both live in the same
# embedding space, then index the documents and search.
X_docs = get_emb(docs)
X_queries = get_emb(queries)
nmslib_index = build_index(X_docs)
# NOTE(review): the DataFrame below is built but never printed or saved —
# presumably this ran in a notebook/REPL where the last expression is
# auto-displayed; wrap in print(...) if running as a plain script.
pd.DataFrame(search(X_queries, queries, docs, nmslib_index))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment