Semantic Search: showing how to ingest the SQuAD dataset into Elasticsearch with Infinity, plus an example of how to query the index.
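Two scripts follow. The first shards the SQuAD validation split across worker processes and bulk-indexes both the full contexts and overlapping two-sentence windows with their embeddings; the second embeds an interactive query and ranks the stored windows by cosine similarity. Both assume an embedding endpoint behind the API_URL placeholders ({URL} and {TASK} are left unfilled here) that takes a JSON "input" string and returns a 768-dimensional "vector", roughly:

    # a minimal sketch of the assumed API contract; {URL} and {TASK} are placeholders, not a real endpoint
    from requests import post
    response = post("https://{URL}/api/{TASK}", json={"input": "Some sentence to embed."})
    vector = response.json()["vector"]  # expected: a list of 768 floats, matching the dense_vector mapping below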
from argparse import ArgumentParser
from multiprocessing import Process
from re import compile

from datasets import load_dataset
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from requests import Session
from tqdm import tqdm

CREDENTIALS = "elastic:XXXXXX"
sentence_re = compile(r"(.*?)\.")
API_URL = "https://{URL}/api/{TASK}"
def process_squad(dataset, worker_id):
    # Yields bulk actions: the full context document, plus one embedded
    # two-sentence window per consecutive sentence pair.
    with Session() as session:
        for item in tqdm(dataset, desc="Inserting SQuAD documents...", position=worker_id):
            doc_id, context = item["id"], item["context"]
            yield {
                "_op_type": "create",
                "_index": "documents",
                "_id": doc_id,
                "context": context
            }
            sentences = [sentence.group(1).strip() + "." for sentence in sentence_re.finditer(context)]
            for current_sent, next_sent in zip(sentences[:-1], sentences[1:]):
                sentence = " ".join((current_sent, next_sent))
                response = session.post(API_URL, json={"input": sentence})
                if response.status_code == 200:
                    content = response.json()
                    yield {
                        "_op_type": "create",
                        "_index": "sentences",
                        "doc_id": doc_id,
                        "text": sentence,
                        "embedding": content["vector"]
                    }
                else:
                    print(f"Non-200 response: {response}")
def process_entry(worker_id, workers):
    # Every worker gets its own client and a contiguous shard of the validation split.
    es_ = Elasticsearch()
    squad_ = load_dataset("squad", split="validation").shard(workers, worker_id, contiguous=True, keep_in_memory=True)
    bulk(es_, process_squad(squad_, worker_id))
def index_squad(es: Elasticsearch, workers: int):
    if not es.indices.exists("documents"):
        print('Creating Elasticsearch index "documents"')
        es.indices.create(index="documents", body={
            "mappings": {
                "properties": {
                    # Stored for retrieval only; not searchable on its own.
                    "context": {"type": "text", "index": "false"}
                }
            }
        })
    if not es.indices.exists("sentences"):
        print('Creating Elasticsearch index "sentences"')
        es.indices.create(index="sentences", body={
            "mappings": {
                "properties": {
                    "doc_id": {"type": "keyword", "index": "false"},
                    "text": {"type": "text", "index": "false"},
                    # Must match the dimensionality of the embedding API's output.
                    "embedding": {"type": "dense_vector", "dims": 768}
                }
            }
        })
    processes = []
    for worker_id in range(workers):
        process = Process(target=process_entry, args=(worker_id, workers))
        process.start()
        processes.append(process)
    for process in processes:
        process.join()
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--workers", type=int, default=1, help="Number of workers")
    args = parser.parse_args()

    es = Elasticsearch()
    index_squad(es, args.workers)
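To run the ingestion end to end, assuming Elasticsearch is reachable on localhost:9200 (the Elasticsearch() default above), the API_URL placeholders are filled in, and the script is saved as ingest_squad.py (a file name chosen here for illustration):

    python ingest_squad.py --workers 4

Each worker bulk-indexes its own contiguous shard, so ingestion parallelizes across processes until the embedding API becomes the bottleneck.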
from requests import Session

from elasticsearch import Elasticsearch

CREDENTIALS = "elastic:XXXXXX"
API_URL = "https://{URL}/api/{TASK}"
def query_elastic():
    es = Elasticsearch()
    while True:
        query = input("Please enter your query (exit to quit): ")
        # query = "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France ?"
        # query = "Superbowl"
        if query.lower() == "exit":
            break
        try:
            with Session() as session:
                # Embed the query with the same API that produced the stored vectors.
                response = session.post(API_URL, json={"input": query})
                if response.status_code == 200:
                    content = response.json()
                    # Score every stored sentence window by cosine similarity to the query vector.
                    sentences = es.search(index="sentences", body={
                        "query": {
                            "script_score": {
                                "query": {"match_all": {}},
                                "script": {
                                    "source": "cosineSimilarity(params.query_vector, 'embedding') + 1",
                                    "params": {
                                        "query_vector": content["vector"]
                                    }
                                }
                            }
                        }
                    })
                    for hit in sentences["hits"]["hits"]:
                        print(f"Score: {hit['_score']} => {hit['_source']['text']}")
                else:
                    print(f"Got error: {response.content}")
        except Exception as e:
            print(e)

if __name__ == '__main__':
    query_elastic()
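The "+ 1" in the script_score source is deliberate: Elasticsearch rejects negative scores, and cosineSimilarity ranges over [-1, 1], so the offset shifts scores into [0, 2] (an exact match scores 2.0). To try it interactively, assuming a hypothetical file name query_squad.py:

    python query_squad.py
    Please enter your query (exit to quit): Superbowl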