Created
October 29, 2024 16:38
-
-
Save manisnesan/11a29bc2bdf681bc927d20da752d6b64 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
# Field names whose values are concatenated into a single text string for
# each document shape recognised by trim_string. Order matters: it is the
# order in which the values appear in the joined text.
UNIFIED_FIELDS = [
    "title",
    "text",
]
CASE_FIELDS = [
    "summary",
    "product",
    "description",
]
SOLUTION_FIELDS = [
    "title",
    "solution_environment",
    "issue",
    "solution_rootcause",
    "solution_diagnosticsteps",
    "solution_resolution",
]
def join_fields(doc, fields):
    """Concatenate the values of selected fields into one string.

    Args:
        doc: Mapping of field name to value.
        fields: Iterable of field names, in the order the values should
            appear in the result.

    Returns:
        The values of the fields that are present in ``doc``, separated by
        a single space. Fields that are missing from ``doc`` or whose value
        is ``None`` are skipped; non-string values are converted with
        ``str``. Returns ``""`` when nothing is selected.
    """
    present = (doc[field] for field in fields if field in doc)
    # A present-but-null field (or a numeric one) would make str.join raise
    # TypeError, so drop None and coerce everything else to text.
    return " ".join(str(value) for value in present if value is not None)
def trim_string(doc):
    """Build the text representation of a query or document.

    The document shape is detected from the keys it carries, and the
    matching field list is joined with ``join_fields``:

    * ``"summary"`` present -> ``CASE_FIELDS``
    * ``"text"`` present    -> ``UNIFIED_FIELDS``
    * ``"title"`` present   -> ``SOLUTION_FIELDS``

    Args:
        doc: Mapping of field name to value.

    Returns:
        The joined text for the detected document shape.

    Raises:
        ValueError: If ``doc`` has none of the keys ``"summary"``,
            ``"text"`` or ``"title"``.
    """
    if "summary" in doc:
        return join_fields(doc, CASE_FIELDS)
    # This check must come before the "title" fallback: unified documents
    # also carry "title", and a second test on "title" could never be
    # reached, which left the UNIFIED_FIELDS branch dead.
    if "text" in doc:
        return join_fields(doc, UNIFIED_FIELDS)
    if "title" in doc:
        return join_fields(doc, SOLUTION_FIELDS)
    raise ValueError(
        "Unknown document type with fields: " + ", ".join(map(str, doc.keys()))
    )
def rerank(server, port, query, documents, top_k=10, timeout=30):
    """Rerank documents against a query using a remote rerank service.

    Sends the first ``top_k`` documents and the query text to
    ``http://{server}:{port}/api/v1/senttransformer/task/rerank`` and maps
    the returned entries back to document URIs.

    Args:
        server: Host name or address of the rerank service.
        port: Port of the rerank service.
        query: Mapping describing the query; converted to text with
            ``trim_string``.
        documents: Sequence of document mappings. Each is converted to
            text with ``trim_string``; ``"title"`` and ``"uri"`` are
            forwarded when present, and ``"uri"`` is required for any
            document that appears in the results.
        top_k: Maximum number of documents to send and of results to
            return.
        timeout: Seconds to wait for the service before
            ``requests`` raises a timeout error. ``None`` waits
            indefinitely.

    Returns:
        A list of ``(uri, score)`` tuples in the order the service
        returned them, or ``[]`` when ``documents`` is empty or ``None``.

    Raises:
        ValueError: If the query or a document has an unknown shape
            (raised by ``trim_string``).
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
    """
    if not documents:
        print("Reranking cannot be performed for null or empty documents")
        return []

    base_url = f"http://{server}:{port}/api/v1/senttransformer/task/rerank"
    query_text = trim_string(query)

    # Only this slice is sent, so result indices are resolved against it
    # rather than against the full input list.
    candidates = documents[:top_k]
    doc_array = [
        {
            "document": {
                "text": trim_string(doc),
                "title": doc.get("title"),
                "url": doc.get("uri"),
            }
        }
        for doc in candidates
    ]
    payload = {
        "inputs": {
            "queries": [query_text],
            "documents": {"documents": doc_array},
            "top_n": top_k,
        }
    }

    # Without a timeout, requests waits forever on an unresponsive server.
    response = requests.post(base_url, json=payload, timeout=timeout)
    response.raise_for_status()  # Raise an error for bad responses

    # The code reads a "results" list whose entries carry "corpus_id"
    # (index into the submitted documents) and "score".
    results = response.json().get("results", [])
    return [
        (candidates[entry["corpus_id"]]["uri"], entry["score"])
        for entry in results[:top_k]
    ]
# ---------------------------------------------------------------------------
# Example usage: rerank three sample documents against a sample query.
# NOTE(review): these statements run at module level, so importing this file
# also performs the request below -- confirm that is intended.
# ---------------------------------------------------------------------------

# Connection settings for the rerank service.
server_address = "example.com"  # Replace with your server address
port = 8080  # Replace with your port number

# Sample query; it carries "summary", "product" and "description".
query = {
    "summary": "Need help with installation issues",
    "product": "Red Hat Enterprise Linux",
    "description": "User is facing problems during installation.",
}

# Sample documents; each has a "uri", which rerank reports in its results.
documents = [
    {
        "title": "Installation Guide for RHEL",
        "uri": "http://example.com/docs/rhel-installation",
        "summary": "A comprehensive guide to install RHEL.",
        "product": "Red Hat Enterprise Linux",
        "description": "This document provides step-by-step instructions.",
    },
    {
        "title": "Troubleshooting RHEL Installation",
        "uri": "http://example.com/docs/rhel-troubleshooting",
        "summary": "Common issues and solutions during RHEL installation.",
        "product": "Red Hat Enterprise Linux",
        "description": "This document helps troubleshoot installation problems.",
    },
    {
        "title": "RHEL Support",
        "uri": "http://example.com/docs/rhel-support",
        "summary": "Support options for RHEL users.",
        "product": "Red Hat Enterprise Linux",
        "description": "Information on how to get support for RHEL.",
    },
]

# Run the rerank call and print one line per (uri, score) pair. Any failure
# is reported on stdout instead of terminating the script.
try:
    results = rerank(server_address, port, query, documents, top_k=10)
    print("Reranked Results:")
    for uri, score in results:
        print(f"Document URI: {uri}, Score: {score}")
except Exception as exc:
    print(f"An error occurred: {exc}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment