Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active June 20, 2019 12:41
Show Gist options
  • Save gaphex/14e4ab4fdbdae7bef4846e50a4c2c4a1 to your computer and use it in GitHub Desktop.
Save gaphex/14e4ab4fdbdae7bef4846e50a4c2c4a1 to your computer and use it in GitHub Desktop.
Nearest Neighbour retriever
class TFRanker:
    """Retrieve the ``top_k`` closest knowledge-base rows for a query vector.

    A TensorFlow 1.x graph is built once at construction time. Each call to
    :meth:`predict` feeds a knowledge-base matrix and a single query vector,
    then returns the indices and distances of the nearest rows.

    The instance owns a ``tf.Session``. Call :meth:`close` when finished, or
    use the ranker as a context manager (``with TFRanker(...) as ranker:``).

    Args:
        dim: Length of the query vector and of every knowledge-base row.
        metric: Callable ``metric(kbase, query)`` that builds a 1-D tensor
            with one distance per knowledge-base row. Smaller values are
            treated as closer.
        top_k: Maximum number of neighbours to return.
    """

    def __init__(self, dim, metric, top_k=3):
        self.dim = dim
        self.top_k = top_k
        self.metric = metric
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.build_graph()

    def build_graph(self):
        """Create the input placeholders and top-k retrieval ops in ``self.graph``."""
        with self.graph.as_default():
            self.query = tf.placeholder("float", [self.dim])
            self.kbase = tf.placeholder("float", [None, self.dim])
            distance = self.metric(self.kbase, self.query)
            # top_k cannot ask for more elements than exist, so clamp k to the
            # number of distances. A knowledge base smaller than top_k then
            # returns every row instead of raising InvalidArgumentError.
            k = tf.minimum(self.top_k, tf.shape(distance)[-1])
            # top_k selects the largest values. Negating the distances makes
            # it select the smallest ones, i.e. the nearest rows.
            top_neg_dists, top_indices = tf.math.top_k(tf.negative(distance), k=k)
            self.top_distances = tf.negative(top_neg_dists)
            self.top_indices = top_indices

    def predict(self, kbase, query):
        """Return the nearest knowledge-base rows for ``query``.

        Args:
            kbase: Array-like of shape ``(n, dim)`` holding the candidate rows.
            query: Array-like of shape ``(dim,)``.

        Returns:
            Tuple ``(indices, distances)``. Each element has at most ``top_k``
            entries (fewer when ``n < top_k``), ordered from nearest to farthest.
        """
        # The session is bound to self.graph, so no default-graph scope is needed.
        indices, distances = self.session.run(
            [self.top_indices, self.top_distances],
            feed_dict={self.query: query, self.kbase: kbase},
        )
        return indices, distances

    def close(self):
        """Release the TensorFlow session owned by this ranker."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Do not suppress exceptions raised inside the with-block.
        return False
def euclidean_distance(kbase, query):
    """Build a tensor of Euclidean distances from ``query`` to each row of ``kbase``.

    Args:
        kbase: Tensor whose rows are compared against ``query``. It is assumed
            to be rank 2 with shape ``(n, dim)``, as fed by ``TFRanker``.
        query: Tensor that broadcasts against ``kbase``, e.g. shape ``(dim,)``.

    Returns:
        Tensor of distances obtained by reducing over axis 1, i.e. one value
        per row of ``kbase``.
    """
    # ``axis`` replaces the deprecated ``reduction_indices`` alias, which
    # TensorFlow 2 removed from ``tf.reduce_sum``. ``axis`` also works in TF 1.x.
    sqr_distance = tf.reduce_sum(tf.square(kbase - query), axis=1)
    return tf.sqrt(sqr_distance)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment