@cyberpunk042
Last active August 16, 2024 21:57
Snippet to query haystack-ai with a prompt and a local or remote file. The file can hold knowledge, data, resources, or any other context relevant to the query.

RAG Pipeline GPT Query Runner (Haystack-AI & OpenAI)

For Developers, Technicians & Engineers

This Python script runs a Retrieval-Augmented Generation (RAG) pipeline built with the Haystack framework. The RAG source is a single file, either downloaded from a URL or read from a local path. The script indexes that file, sends your query through the RAG pipeline to OpenAI, and outputs the result, which can be saved to a file or printed directly to the console.

Requirements

Ensure that you have the necessary Python packages installed. You can do this by using the requirements.txt provided.

pip install -r requirements.txt

Usage

You can run the script from the command line with the following arguments:

Required Arguments

  • --api_key: Your OpenAI API Key.
  • --query: The query you want to send to the pipeline.

Optional Arguments

  • --url: The URL to download the prompt data file from. If provided without a --file_path, the file will be saved as prompt-data.txt by default.
  • --file_path: The path to the prompt data file. This can be either a path to a file you've downloaded manually or where the file will be saved if a URL is provided.
  • --output_name: The name of the output file where the result will be saved. If not provided, the result will be printed to the console.
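The way --url and --file_path combine can be summarized in a few lines. The following is a minimal sketch of the behaviour described above; `resolve_input_path` and `DEFAULT_FILE_PATH` are hypothetical names used for illustration and do not appear in the script.

```python
DEFAULT_FILE_PATH = "prompt-data.txt"  # default name given in the --url description


def resolve_input_path(url=None, file_path=None):
    """Return (path, needs_download) for a combination of --url and --file_path.

    Hypothetical helper, for illustration only.
    """
    if url:
        # A URL was given: download to --file_path, or to the default name.
        return (file_path or DEFAULT_FILE_PATH, True)
    if file_path:
        # Only a local path was given: use it as-is, nothing to download.
        return (file_path, False)
    # Neither option was given: there is no input to index.
    raise ValueError("Either a URL or a file path must be provided.")
```

With only a URL the helper returns `("prompt-data.txt", True)`; with only a file path it returns that path and `False`.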

Examples

  1. Using a URL:

    python query_with_rag_data.py --api_key "your_openai_api_key" --query "Provide me a response to (xyz)" --url "https://www.web.domain/hosted/prompt-data.txt"
    
  2. Using a Local File Path:

    python query_with_rag_data.py --api_key "your_openai_api_key" --query "Provide me a response to (xyz)" --file_path "prompt-data.txt"
    
  3. Saving the Output to a File:

    python query_with_rag_data.py --api_key "your_openai_api_key" --query "Provide me a response to (xyz)" --file_path "prompt-data.txt" --output_name "response.txt"
    

Logging

The script uses Python's built-in logging module to log important actions and errors. These logs can help you understand what the script is doing and troubleshoot any issues.
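If you import the script's functions into your own code, you may want more detail than the default INFO level, or a copy of the log on disk. The sketch below uses only the standard logging module; `configure_logging` and the logger name are assumptions made for this example, not part of the script.

```python
import logging


def configure_logging(level=logging.INFO, log_file=None):
    """Configure root logging for the console and, optionally, a file.

    Hypothetical helper, for illustration only.
    """
    handlers = [logging.StreamHandler()]
    if log_file:
        # Also append every record to the given file.
        handlers.append(logging.FileHandler(log_file))
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        handlers=handlers,
        force=True,  # replace any earlier basicConfig call (Python 3.8+)
    )
    return logging.getLogger("query_with_rag_data")
```

Calling `configure_logging(level=logging.DEBUG, log_file="run.log")` before running the pipelines would write debug-level records to both the console and `run.log`.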

Error Handling

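The script validates its inputs, logs each failure, and re-raises the exception. A missing API key, or a run with neither --url nor --file_path, raises ValueError; a file path that does not exist raises FileNotFoundError; download and pipeline failures are logged and propagated unchanged.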
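A caller that prefers exit codes to tracebacks can wrap the entry point. This is a sketch under the assumption that input problems surface as ValueError or FileNotFoundError; `run_cli` and the specific exit codes are hypothetical choices for this example.

```python
import sys


def run_cli(main_func, **kwargs):
    """Call main_func(**kwargs) and translate failures into exit codes.

    Hypothetical wrapper, for illustration only.
    """
    try:
        main_func(**kwargs)
    except (ValueError, FileNotFoundError) as exc:
        # Input problems: missing key, missing file, or no input source given.
        print(f"Input error: {exc}", file=sys.stderr)
        return 2
    except Exception as exc:
        # Anything else, such as a failed download or pipeline run.
        print(f"Unexpected error: {exc}", file=sys.stderr)
        return 1
    return 0
```

It could be used as `sys.exit(run_cli(main, api_key=..., query=..., file_path=...))`, where `main` is the script's entry function.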

How It Works

  1. Set the OpenAI API Key:

    • The script begins by setting the OpenAI API key as an environment variable, which is required for pipeline execution.
  2. Handle Input File:

    • If a URL is provided, the script downloads the file to the specified or default file path.
    • If a local file path is provided, the script validates that the file exists.
  3. Pipeline Initialization:

    • The script initializes two pipelines: an indexing pipeline and a RAG (Retrieval-Augmented Generation) pipeline.
  4. Pipeline Execution:

    • The indexing pipeline processes the content of the file.
    • The RAG pipeline generates a response based on the provided query.
  5. Output:

    • The generated response is either saved to a specified file or printed to the console if no file name is provided.
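The ordering of steps 4 and 5 can be sketched without Haystack by passing in stand-in callables. In the example below, `run_flow`, `index_fn`, and `answer_fn` are hypothetical names; in the real script the two callables correspond to running the indexing pipeline and the RAG pipeline.

```python
import sys


def run_flow(query, file_path, index_fn, answer_fn, output_name=None, out=sys.stdout):
    """Index the file, answer the query, then save or print the response.

    index_fn and answer_fn are stand-ins for the indexing and RAG pipelines.
    """
    index_fn(file_path)          # step 4: index the input file first
    response = answer_fn(query)  # step 4: then generate an answer for the query
    if output_name:
        # Step 5: a file name was given, so save the response there.
        with open(output_name, "w", encoding="utf-8") as fh:
            fh.write(response)
    else:
        # Step 5: no file name, so print the response instead.
        print(response, file=out)
    return response
```

Indexing must finish before the query runs, because the RAG step retrieves from whatever the indexing step stored.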

Additional Notes

  • Ensure that your input file is plain text that the indexing pipeline can read; the default file name, prompt-data.txt, reflects this.
  • The script targets the common case: one input file, one query, and one response.
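A quick pre-flight check can catch unsuitable input files before any API call is made. The sketch below assumes the input should be plain text that decodes as UTF-8; that requirement is an assumption of this example, and `looks_like_utf8_text` is a hypothetical helper that is not part of the script.

```python
def looks_like_utf8_text(path):
    """Return True if the file decodes as UTF-8 and contains no NUL bytes.

    Hypothetical helper, for illustration only.
    """
    with open(path, "rb") as fh:
        data = fh.read()
    if b"\x00" in data:
        # NUL bytes usually indicate a binary file rather than text.
        return False
    try:
        data.decode("utf-8")
    except UnicodeDecodeError:
        return False
    return True
```

Running it on the input file before indexing gives an early, local failure instead of an error from inside the pipeline.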
import argparse
import logging
import os
import urllib.request

from haystack import Pipeline, PredefinedPipeline

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def set_openai_api_key(api_key):
    """
    Set the OpenAI API key as an environment variable.

    Args:
        api_key (str): The OpenAI API key.

    Raises:
        ValueError: If the API key is not provided.
    """
    if not api_key:
        logger.error("OpenAI API Key is required.")
        raise ValueError("OpenAI API Key is required.")
    os.environ["OPENAI_API_KEY"] = api_key
    logger.info("OpenAI API Key set successfully.")


def validate_file_path(file_path):
    """
    Validate that the provided file path exists.

    Args:
        file_path (str): The path to the file.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if not os.path.exists(file_path):
        logger.error(f"File {file_path} does not exist.")
        raise FileNotFoundError(f"File {file_path} does not exist.")
    logger.info(f"File {file_path} exists and is valid.")


def download_file(url, filename):
    """
    Download a file from a given URL and save it to the specified filename.

    Args:
        url (str): The URL to download the file from.
        filename (str): The local path where the file should be saved.

    Raises:
        Exception: If the file download fails.
    """
    try:
        logger.info(f"Downloading file from {url} to {filename}")
        urllib.request.urlretrieve(url, filename)
        logger.info("File downloaded successfully.")
    except Exception as e:
        logger.error(f"Failed to download file: {e}")
        raise


def initialize_pipeline(pipeline_type):
    """
    Initialize the specified pipeline.

    Args:
        pipeline_type (str): The type of pipeline to initialize ('indexing' or 'rag').

    Returns:
        Pipeline: The initialized pipeline.

    Raises:
        ValueError: If an invalid pipeline type is provided.
        Exception: If pipeline initialization fails.
    """
    try:
        logger.info(f"Initializing {pipeline_type} pipeline.")
        if pipeline_type == 'indexing':
            pipeline = Pipeline.from_template(PredefinedPipeline.INDEXING)
        elif pipeline_type == 'rag':
            pipeline = Pipeline.from_template(PredefinedPipeline.RAG)
        else:
            logger.error("Invalid pipeline type provided.")
            raise ValueError("Invalid pipeline type provided.")
        return pipeline
    except Exception as e:
        logger.error(f"Failed to initialize pipeline: {e}")
        raise


def run_indexing_pipeline(indexing_pipeline, file_path):
    """
    Run the indexing pipeline on the provided file.

    Args:
        indexing_pipeline (Pipeline): The initialized indexing pipeline.
        file_path (str): The path to the file to be indexed.

    Raises:
        Exception: If there is an error during pipeline execution.
    """
    try:
        validate_file_path(file_path)
        logger.info("Running the indexing pipeline.")
        # Haystack 2.x pipelines take inputs keyed by component name. The
        # predefined indexing template is assumed here to expose a "converter"
        # component that accepts a list of file paths as "sources", so the
        # file no longer needs to be read manually.
        indexing_pipeline.run(data={"converter": {"sources": [file_path]}})
        logger.info("Indexing pipeline executed successfully.")
    except Exception as e:
        logger.error(f"Error during indexing pipeline execution: {e}")
        raise


def run_rag_pipeline(rag_pipeline, query):
    """
    Run the RAG (Retrieval-Augmented Generation) pipeline with the provided query.

    Args:
        rag_pipeline (Pipeline): The initialized RAG pipeline.
        query (str): The query to be processed by the RAG pipeline.

    Returns:
        str: The response generated by the RAG pipeline.

    Raises:
        Exception: If there is an error during pipeline execution.
    """
    try:
        logger.info("Running the RAG pipeline.")
        result = rag_pipeline.run(data={"prompt_builder": {"query": query}, "text_embedder": {"text": query}})
        # Haystack 2.x returns outputs keyed by component name. The predefined
        # RAG template is assumed here to return the generated text under
        # result["llm"]["replies"] rather than a top-level "answers" list.
        replies = result.get("llm", {}).get("replies", [])
        if replies:
            response = replies[0]
            logger.info("RAG pipeline executed successfully.")
        else:
            response = "No response generated."
            logger.warning("RAG pipeline did not generate any answers.")
        return response
    except Exception as e:
        logger.error(f"Error during RAG pipeline execution: {e}")
        raise


def save_output(response, output_name):
    """
    Save the generated response to a file or print it to the console.

    Args:
        response (str): The response to be saved or printed.
        output_name (str): The name of the file to save the response. If None, print to the console.
    """
    if output_name:
        try:
            with open(output_name, 'w') as output_file:
                logger.info(f"Saving result to {output_name}.")
                output_file.write(response)
        except Exception as e:
            logger.error(f"Failed to save output: {e}")
            raise
    else:
        logger.info("Output file not provided, printing result to console.")
        print(response)


def main(api_key, query, url=None, file_path=None, output_name=None):
    """
    Main function to orchestrate the pipeline execution.

    Args:
        api_key (str): The OpenAI API key.
        query (str): The query to be processed by the pipeline.
        url (str, optional): URL to download the file from.
        file_path (str, optional): Path to the local file.
        output_name (str, optional): Name of the file to save the response.
    """
    # Set the OpenAI API key
    set_openai_api_key(api_key)

    # Handle URL vs. file path input
    if url:
        if not file_path:
            file_path = "prompt-data.txt"
        download_file(url, file_path)
    elif file_path:
        validate_file_path(file_path)
    else:
        logger.error("Either a URL or a file path must be provided.")
        raise ValueError("Either a URL or a file path must be provided.")

    # Initialize and run the pipelines
    indexing_pipeline = initialize_pipeline('indexing')
    run_indexing_pipeline(indexing_pipeline, file_path)
    rag_pipeline = initialize_pipeline('rag')
    response = run_rag_pipeline(rag_pipeline, query)

    # Save or print the output
    save_output(response, output_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a pipeline with a query, URL, or file input.")
    parser.add_argument("--api_key", required=True, help="OpenAI API Key")
    parser.add_argument("--query", required=True, help="Query to send to the pipeline")
    parser.add_argument("--url", help="URL to download the prompt data file from")
    parser.add_argument("--file_path", help="Path to the prompt data file")
    parser.add_argument("--output_name", help="Name of the output file to save the result")
    args = parser.parse_args()
    main(
        api_key=args.api_key,
        query=args.query,
        url=args.url,
        file_path=args.file_path,
        output_name=args.output_name,
    )
haystack-ai==2.4.0
urllib3==2.2.2