Skip to content

Instantly share code, notes, and snippets.

@manisnesan
Created October 29, 2024 16:38
Show Gist options
  • Save manisnesan/11a29bc2bdf681bc927d20da752d6b64 to your computer and use it in GitHub Desktop.
Save manisnesan/11a29bc2bdf681bc927d20da752d6b64 to your computer and use it in GitHub Desktop.
import requests
# Field lists used to flatten each known document type into a single string.
# Within each list, the order is the order in which field values are joined.
UNIFIED_FIELDS = ["title", "text"]
CASE_FIELDS = ["summary", "product", "description"]
SOLUTION_FIELDS = [
    "title",
    "solution_environment",
    "issue",
    "solution_rootcause",
    "solution_diagnosticsteps",
    "solution_resolution",
]


def join_fields(doc, fields):
    """Join the values of ``fields`` found in ``doc`` with single spaces.

    Fields that are missing, ``None`` or the empty string are skipped, so an
    absent value neither raises nor leaves a double space in the output.

    Args:
        doc: Mapping of field name to string value.
        fields: Field names to read, in output order.

    Returns:
        The space-separated concatenation of the selected values, or ``""``
        when none of them is present.
    """
    values = (doc.get(field) for field in fields)
    return " ".join(value for value in values if value is not None and value != "")


def trim_string(doc):
    """Flatten a case, solution or unified document into one text string.

    The document type is detected from its keys, in this order:

    * ``"summary"`` present -> case document, joined via ``CASE_FIELDS``.
    * any solution-only field present -> solution document, joined via
      ``SOLUTION_FIELDS``.
    * ``"title"`` or ``"text"`` present -> unified document, joined via
      ``UNIFIED_FIELDS``.

    Args:
        doc: Mapping of field name to string value.

    Returns:
        The flattened text of the document.

    Raises:
        ValueError: if none of the recognized keys is present.
    """
    if "summary" in doc:
        return join_fields(doc, CASE_FIELDS)
    # Unified documents carry a "title" too, so "title" alone cannot identify a
    # solution document; testing it first would drop a unified doc's "text".
    # Solution documents are recognized by their solution-only fields instead.
    # A title-only document still flattens to just its title.
    if any(field in doc for field in SOLUTION_FIELDS if field != "title"):
        return join_fields(doc, SOLUTION_FIELDS)
    if "title" in doc or "text" in doc:
        return join_fields(doc, UNIFIED_FIELDS)
    # map(str, ...) keeps the error message itself from failing on non-str keys.
    raise ValueError("Unknown document type with fields: " + ", ".join(map(str, doc)))
def rerank(server, port, query, documents, top_k=10, timeout=30):
    """Rerank ``documents`` against ``query`` via a remote rerank service.

    The query and each of the first ``top_k`` documents are flattened with
    ``trim_string``. They are then POSTed as JSON to the service's
    ``/api/v1/senttransformer/task/rerank`` endpoint.

    Args:
        server: Host name of the rerank service.
        port: Port of the rerank service.
        query: Query document (any shape accepted by ``trim_string``).
        documents: Candidate documents; only the first ``top_k`` are sent.
        top_k: Number of documents to send, passed as ``top_n``, and the
            maximum number of results returned.
        timeout: Seconds ``requests`` waits for the service before raising a
            timeout error; pass ``None`` to wait indefinitely.

    Returns:
        A list of ``(uri, score)`` tuples in the order the service returns
        them. ``uri`` is ``None`` for a document without a ``"uri"`` key.
        Returns an empty list when ``documents`` is empty or ``None``.

    Raises:
        ValueError: if the query or a document has an unrecognized shape.
        requests.HTTPError: if the service answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    base_url = f"http://{server}:{port}/api/v1/senttransformer/task/rerank"
    if not documents:
        print("Reranking cannot be performed for null or empty documents")
        return []
    query_text = trim_string(query)
    doc_array = [
        {
            "document": {
                "text": trim_string(doc),
                "title": doc.get("title"),
                "url": doc.get("uri"),
            }
        }
        for doc in documents[:top_k]
    ]
    payload = {
        "inputs": {
            "queries": [query_text],
            "documents": {"documents": doc_array},
            "top_n": top_k,
        }
    }
    # Without a timeout, requests waits forever on an unresponsive server.
    response = requests.post(base_url, json=payload, timeout=timeout)
    response.raise_for_status()  # Surface 4xx/5xx responses as requests.HTTPError.
    results = response.json().get("results", [])
    # NOTE(review): assumes corpus_id is the position in the submitted doc_array,
    # which equals the position in ``documents``; confirm against the service API.
    # .get("uri") matches the payload above: a document without a URI yields
    # None instead of raising KeyError after a successful request.
    return [
        (documents[entry["corpus_id"]].get("uri"), entry["score"])
        for entry in results[:top_k]
    ]
# Example usage: rerank a sample query against sample documents.
# Guarded so that importing this module does not fire a network request.
if __name__ == "__main__":
    # Sample data for demonstration
    server_address = "example.com"  # Replace with your server address
    port = 8080  # Replace with your port number

    # Sample query; the "summary" key makes trim_string treat it as a case.
    query = {
        "summary": "Need help with installation issues",
        "product": "Red Hat Enterprise Linux",
        "description": "User is facing problems during installation.",
    }
    documents = [
        {
            "title": "Installation Guide for RHEL",
            "uri": "http://example.com/docs/rhel-installation",
            "summary": "A comprehensive guide to install RHEL.",
            "product": "Red Hat Enterprise Linux",
            "description": "This document provides step-by-step instructions.",
        },
        {
            "title": "Troubleshooting RHEL Installation",
            "uri": "http://example.com/docs/rhel-troubleshooting",
            "summary": "Common issues and solutions during RHEL installation.",
            "product": "Red Hat Enterprise Linux",
            "description": "This document helps troubleshoot installation problems.",
        },
        {
            "title": "RHEL Support",
            "uri": "http://example.com/docs/rhel-support",
            "summary": "Support options for RHEL users.",
            "product": "Red Hat Enterprise Linux",
            "description": "Information on how to get support for RHEL.",
        },
    ]

    # Call the rerank function. Only this call can fail in practice, so it is
    # the only statement inside the try. The broad except is intentional: this
    # is the script's top-level boundary, and any failure is reported to the user.
    try:
        results = rerank(server_address, port, query, documents, top_k=10)
    except Exception as e:
        print(f"An error occurred: {e}")
    else:
        print("Reranked Results:")
        for uri, score in results:
            print(f"Document URI: {uri}, Score: {score}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment