Semantic Search: how to ingest the SQuAD dataset into Elasticsearch with Infinity embeddings, plus an example of how to query the index. The first script shards the SQuAD validation split across worker processes and bulk-indexes full contexts and embedded sentence pairs; the second embeds a free-text query and ranks the indexed sentences by cosine similarity.
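Concretely, the ingestion script below writes two kinds of documents. A sketch of their shapes, taken from the bulk actions the script yields (the ids and vector values here are invented for illustration):

# Full SQuAD context, stored but not searchable ("index": false):
document = {"_index": "documents", "_id": "<squad id>", "context": "<full paragraph>"}

# Overlapping two-sentence window plus its 768-dim embedding:
sentence = {
    "_index": "sentences",
    "doc_id": "<squad id>",           # back-reference to the parent document
    "text": "<sentence pair>",
    "embedding": [0.12, -0.03, ...],  # 768 floats from the embedding endpoint
}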
The ingestion script:

from argparse import ArgumentParser
from multiprocessing import Process
from re import compile as re_compile

from datasets import load_dataset
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from requests import Session
from tqdm import tqdm

CREDENTIALS = "elastic:XXXXXX"  # placeholder credentials (unused below)
sentence_re = re_compile(r"(.*?)\.")
API_URL = "https://{URL}/api/{TASK}"  # placeholder: embedding endpoint


def process_squad(dataset, worker_id):
    """Yield bulk actions: one full context per SQuAD item, plus one
    embedded two-sentence window per consecutive sentence pair."""
    with Session() as session:
        for item in tqdm(dataset, desc="Inserting SQuAD documents...", position=worker_id):
            doc_id, context = item["id"], item["context"]
            yield {
                "_op_type": "create",
                "_index": "documents",
                "_id": doc_id,
                "context": context,
            }
            sentences = [match.group(1).strip() + "." for match in sentence_re.finditer(context)]
            # Overlapping windows of two consecutive sentences.
            for current_sent, next_sent in zip(sentences[:-1], sentences[1:]):
                sentence = " ".join((current_sent, next_sent))
                response = session.post(API_URL, json={"input": sentence})
                if response.status_code == 200:
                    content = response.json()
                    yield {
                        "_op_type": "create",
                        "_index": "sentences",
                        "doc_id": doc_id,
                        "text": sentence,
                        "embedding": content["vector"],
                    }
                else:
                    print(f"Non-200 response: {response}")


def process_entry(worker_id, workers):
    es_ = Elasticsearch()
    # Each worker indexes its own contiguous shard of the validation split.
    squad_ = load_dataset("squad", split="validation").shard(workers, worker_id, contiguous=True, keep_in_memory=True)
    bulk(es_, process_squad(squad_, worker_id))


def index_squad(es: Elasticsearch, workers: int):
    if not es.indices.exists(index="documents"):
        print('Creating Elasticsearch index "documents"')
        es.indices.create(index="documents", body={
            "mappings": {
                "properties": {
                    "context": {
                        "type": "text",
                        "index": False
                    }
                }
            }
        })
    if not es.indices.exists(index="sentences"):
        print('Creating Elasticsearch index "sentences"')
        es.indices.create(index="sentences", body={
            "mappings": {
                "properties": {
                    "doc_id": {
                        "type": "keyword",
                        "index": False
                    },
                    "text": {
                        "type": "text",
                        "index": False
                    },
                    "embedding": {
                        "type": "dense_vector",
                        "dims": 768
                    }
                }
            }
        })
    processes = []
    for worker_id in range(workers):
        process = Process(target=process_entry, args=(worker_id, workers))
        process.start()
        processes.append(process)
    for process in processes:
        process.join()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--workers", type=int, default=1, help="Number of worker processes")
    args = parser.parse_args()
    es = Elasticsearch()
    index_squad(es, args.workers)
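Both scripts assume the embedding endpoint accepts a JSON body of the form {"input": "<text>"} and answers with a body containing a "vector" key; beyond these calls, neither the URL placeholders nor the response schema is documented in the gist. A minimal smoke test of that assumed contract (fill in the placeholders first):

from requests import Session

API_URL = "https://{URL}/api/{TASK}"  # same placeholder as above

with Session() as session:
    response = session.post(API_URL, json={"input": "A short test sentence."})
    response.raise_for_status()
    vector = response.json()["vector"]
    # The "sentences" mapping declares a 768-dim dense_vector, so the
    # endpoint must return exactly 768 floats per input.
    assert len(vector) == 768, f"expected 768 dims, got {len(vector)}"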
The query script embeds a free-text query with the same endpoint and ranks the indexed sentence pairs by cosine similarity:

from elasticsearch import Elasticsearch
from requests import Session

CREDENTIALS = "elastic:XXXXXX"  # placeholder credentials (unused below)
API_URL = "https://{URL}/api/{TASK}"  # placeholder: embedding endpoint


def query_elastic():
    es = Elasticsearch()
    while True:
        query = input("Please enter your query (exit to quit): ")
        # query = "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France ?"
        # query = "Superbowl"
        if query.lower() == "exit":
            break
        try:
            with Session() as session:
                # Embed the query with the same endpoint used at ingestion time.
                response = session.post(API_URL, json={"input": query})
                if response.status_code == 200:
                    content = response.json()
                    # Rank all sentences by cosine similarity to the query vector.
                    sentences = es.search(index="sentences", body={
                        "query": {
                            "script_score": {
                                "query": {"match_all": {}},
                                "script": {
                                    "source": "cosineSimilarity(params.query_vector, 'embedding') + 1",
                                    "params": {
                                        "query_vector": content["vector"]
                                    }
                                }
                            }
                        }
                    })
                    for hit in sentences["hits"]["hits"]:
                        print(f"Score: {hit['_score']} => {hit['_source']['text']}")
                else:
                    print(f"Got error: {response.content}")
        except Exception as e:
            print(e)


if __name__ == "__main__":
    query_elastic()
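The Painless script cosineSimilarity(params.query_vector, 'embedding') + 1 shifts cosine similarity from [-1, 1] into [0, 2], because script_score queries must not produce negative scores. A pure-Python sketch of the same scoring, for illustration only:

from math import sqrt

def script_score(query_vector, embedding):
    # Mirrors the Painless script above: cosine similarity plus 1,
    # so the lowest possible score is 0 instead of -1.
    dot = sum(q * e for q, e in zip(query_vector, embedding))
    norms = sqrt(sum(q * q for q in query_vector)) * sqrt(sum(e * e for e in embedding))
    return dot / norms + 1

print(script_score([1.0, 0.0], [1.0, 0.0]))  # 2.0: same direction
print(script_score([1.0, 0.0], [0.0, 1.0]))  # 1.0: orthogonal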