|
import os |
|
import logging |
|
import argparse |
|
from haystack import Pipeline, PredefinedPipeline |
|
import urllib.request |
|
|
|
# Configure root logging at INFO level. Per the logging docs, basicConfig is a
# no-op if the root logger already has handlers attached.
logging.basicConfig(level=logging.INFO)

# Logger shared by every function in this module.
logger = logging.getLogger(name=__name__)
|
|
|
def set_openai_api_key(api_key):
    """
    Export the OpenAI API key as ``OPENAI_API_KEY`` in this process's environment.

    Args:
        api_key (str): The OpenAI API key. Any falsy value (``None``, ``""``)
            is rejected.

    Raises:
        ValueError: If the API key is falsy.
    """
    missing_msg = "OpenAI API Key is required."
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
        logger.info("OpenAI API Key set successfully.")
        return
    logger.error(missing_msg)
    raise ValueError(missing_msg)
|
|
|
def validate_file_path(file_path):
    """
    Validate that the provided path exists and is a regular file.

    ``os.path.exists`` alone also accepts directories, which would pass here
    and then fail later when opened for reading. They are rejected up front
    instead, with the same exception type callers already handle.

    Args:
        file_path (str): The path to the file.

    Raises:
        FileNotFoundError: If the path does not exist or is not a regular file.
    """
    if not os.path.exists(file_path):
        message = f"File {file_path} does not exist."
        logger.error(message)
        raise FileNotFoundError(message)
    if not os.path.isfile(file_path):
        # Exists, but is a directory or another non-regular entry.
        message = f"Path {file_path} is not a regular file."
        logger.error(message)
        raise FileNotFoundError(message)
    logger.info("File %s exists and is valid.", file_path)
|
|
|
def download_file(url, filename):
    """
    Download a file from a given URL and save it to the specified filename.

    The payload is first written to ``<filename>.part`` and moved over
    *filename* only after ``urlretrieve`` returns successfully. Writing
    straight to *filename* would have two problems when the transfer fails
    part-way (e.g. ``urllib.error.ContentTooShortError``):

    - it leaves a truncated file that a later run could silently use;
    - it destroys any previous copy of *filename*.

    Args:
        url (str): The URL to download the file from.
        filename (str): The local path where the file should be saved.

    Raises:
        Exception: Whatever the download or the final rename raised; it is
            logged and re-raised unchanged.
    """
    partial_path = f"{filename}.part"
    try:
        logger.info("Downloading file from %s to %s", url, filename)
        urllib.request.urlretrieve(url, partial_path)
        # The temporary file is a sibling of the target, so both sit on the
        # same filesystem and the rename is atomic on POSIX.
        os.replace(partial_path, filename)
        logger.info("File downloaded successfully.")
    except Exception as e:
        logger.error("Failed to download file: %s", e)
        try:
            os.remove(partial_path)
        except OSError:
            # Best-effort cleanup. The partial file may never have been
            # created. Do not mask the original error.
            pass
        raise
|
|
|
def initialize_pipeline(pipeline_type):
    """
    Initialize the specified predefined pipeline.

    Args:
        pipeline_type (str): The type of pipeline to initialize ('indexing' or 'rag').

    Returns:
        Pipeline: The initialized pipeline.

    Raises:
        ValueError: If an invalid pipeline type is provided.
        Exception: If building the pipeline from its template fails.
    """
    logger.info("Initializing %s pipeline.", pipeline_type)

    # Validate before the try below. Otherwise the broad handler would log a
    # bad argument a second time as "Failed to initialize pipeline".
    if pipeline_type == "indexing":
        template = PredefinedPipeline.INDEXING
    elif pipeline_type == "rag":
        template = PredefinedPipeline.RAG
    else:
        logger.error("Invalid pipeline type provided.")
        raise ValueError("Invalid pipeline type provided.")

    try:
        return Pipeline.from_template(template)
    except Exception as e:
        logger.error("Failed to initialize pipeline: %s", e)
        raise
|
|
|
def run_indexing_pipeline(indexing_pipeline, file_path):
    """
    Run the indexing pipeline on the provided file.

    The file is passed by path. Per the Haystack quickstart, the predefined
    INDEXING template accepts paths through its converter's ``sources`` input
    and reads and converts the file itself. Passing ``{"documents": [...]}``
    with plain dicts targets a socket that is already connected inside the
    template, and the dicts are not Document objects.

    Args:
        indexing_pipeline (Pipeline): The initialized indexing pipeline.
        file_path (str): The path to the file to be indexed.

    Raises:
        FileNotFoundError: If *file_path* is not an existing regular file.
        Exception: If there is an error during pipeline execution.
    """
    try:
        validate_file_path(file_path)
        logger.info("Running the indexing pipeline.")
        indexing_pipeline.run(data={"sources": [file_path]})
        logger.info("Indexing pipeline executed successfully.")
    except Exception as e:
        logger.error("Error during indexing pipeline execution: %s", e)
        raise
|
|
|
def run_rag_pipeline(rag_pipeline, query):
    """
    Run the RAG (Retrieval-Augmented Generation) pipeline with the provided query.

    ``Pipeline.run`` returns outputs keyed by component name. Per the Haystack
    quickstart, the predefined RAG template's generator is the ``llm``
    component, and its text is in ``result["llm"]["replies"]``. The old
    top-level ``"answers"`` lookup never matched that shape, so it always
    reported that no response was generated.

    Args:
        rag_pipeline (Pipeline): The initialized RAG pipeline.
        query (str): The query to be processed by the RAG pipeline.

    Returns:
        str: The first reply generated by the pipeline, or
            "No response generated." if it produced none.

    Raises:
        Exception: If there is an error during pipeline execution.
    """
    try:
        logger.info("Running the RAG pipeline.")
        result = rag_pipeline.run(
            data={"prompt_builder": {"query": query}, "text_embedder": {"text": query}}
        )

        replies = (result.get("llm") or {}).get("replies") or []
        # Kept for pipelines that return a top-level "answers" list of objects
        # with an ``.answer`` attribute, as the previous implementation expected.
        legacy_answers = result.get("answers") or []

        if replies:
            response = replies[0]
        elif legacy_answers:
            response = legacy_answers[0].answer
        else:
            logger.warning("RAG pipeline did not generate any answers.")
            return "No response generated."

        logger.info("RAG pipeline executed successfully.")
        return response
    except Exception as e:
        logger.error("Error during RAG pipeline execution: %s", e)
        raise
|
|
|
def save_output(response, output_name):
    """
    Save the generated response to a file or print it to the console.

    Args:
        response (str): The response to be saved or printed.
        output_name (str): The name of the file to save the response. If
            falsy (None or ""), the response is printed to the console.

    Raises:
        Exception: If the output file cannot be opened or written; it is
            logged and re-raised.
    """
    if not output_name:
        logger.info("Output file not provided, printing result to console.")
        print(response)
        return

    try:
        # Explicit UTF-8: generated text may contain characters the platform
        # default encoding cannot represent, which would raise
        # UnicodeEncodeError on write.
        with open(output_name, "w", encoding="utf-8") as output_file:
            logger.info("Saving result to %s.", output_name)
            output_file.write(response)
    except Exception as e:
        logger.error("Failed to save output: %s", e)
        raise
|
|
|
def main(api_key, query, url=None, file_path=None, output_name=None):
    """
    Index a document, ask the RAG pipeline a question about it, and emit the answer.

    The document comes either from *url* or from an existing local *file_path*.
    A URL is downloaded to *file_path*, or to ``prompt-data.txt`` when no path
    is given.

    Args:
        api_key (str): The OpenAI API key.
        query (str): The query to be processed by the pipeline.
        url (str, optional): URL to download the file from.
        file_path (str, optional): Path to the local file (or the download
            target when *url* is given).
        output_name (str, optional): Name of the file to save the response.
            When omitted, the response is printed.

    Raises:
        ValueError: If the API key is missing, or neither *url* nor
            *file_path* is provided.
    """
    set_openai_api_key(api_key)

    # Resolve the document source. A URL takes precedence over a local path.
    if not url and not file_path:
        logger.error("Either a URL or a file path must be provided.")
        raise ValueError("Either a URL or a file path must be provided.")
    if url:
        file_path = file_path or "prompt-data.txt"
        download_file(url, file_path)
    else:
        validate_file_path(file_path)

    # Index first, then query. Each pipeline is built just before it runs.
    run_indexing_pipeline(initialize_pipeline('indexing'), file_path)
    answer = run_rag_pipeline(initialize_pipeline('rag'), query)

    save_output(answer, output_name)
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a pipeline with a query, URL, or file input.")
    # Fall back to the environment so the key need not be typed on the command
    # line. There it is visible to other users via the process list and is
    # saved in shell history. --api_key is still accepted for compatibility.
    parser.add_argument(
        "--api_key",
        default=os.environ.get("OPENAI_API_KEY"),
        help="OpenAI API Key (defaults to the OPENAI_API_KEY environment variable)",
    )
    parser.add_argument("--query", required=True, help="Query to send to the pipeline")
    parser.add_argument("--url", help="URL to download the prompt data file from")
    parser.add_argument("--file_path", help="Path to the prompt data file")
    parser.add_argument("--output_name", help="Name of the output file to save the result")

    args = parser.parse_args()
    if not args.api_key:
        # Report a missing key as a usage error, as the old required=True
        # did, rather than as a traceback from main().
        parser.error("an OpenAI API key is required: pass --api_key or set OPENAI_API_KEY")

    main(
        api_key=args.api_key,
        query=args.query,
        url=args.url,
        file_path=args.file_path,
        output_name=args.output_name,
    )