Created October 23, 2024 14:13
Using Ollama to iterate over a source of truth and present the generated prompts and responses to an expert for review.
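The script assumes a local Ollama server is running with the model already pulled (ollama pull mistral:latest), that the source of truth is available as a single plain-text file, and that the CSV splits that same text into sections, one row per section. Column names are free-form, since every non-empty cell in a row is concatenated into the prompt. A minimal input CSV might look like this (file name and columns are illustrative, not part of the gist):

section,content
1,"Letter I: Walton writes to his sister as he prepares his voyage north."
2,"Letter II: Walton laments that he has no friend to share his enthusiasm."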
import csv
import json
import requests
from pathlib import Path
from typing import List, Dict
import sys
from datetime import datetime

# Global constants
OLLAMA_MODEL = "mistral:latest"
OLLAMA_URL = "http://127.0.0.1:11434"
SOURCE_TEXT = ""  # Will store the original text


def get_source_text() -> str:
    """Get the source text file location and read the content."""
    while True:
        file_location = input("Please enter the path to the original text file: ").strip()
        if Path(file_location).is_file():
            try:
                with open(file_location, 'r', encoding='utf-8') as file:
                    content = file.read()
                print("\nSuccessfully loaded source text.")
                return content
            except Exception as e:
                print(f"Error reading source file: {e}")
        else:
            print("File not found. Please try again.")


def get_model_info() -> Dict:
    """Get detailed information about the Ollama model."""
    try:
        response = requests.post(f"{OLLAMA_URL}/api/show", json={"name": OLLAMA_MODEL})
        if response.status_code == 200:
            return response.json()
        return {"error": f"Failed to get model info: {response.status_code}"}
    except Exception as e:
        return {"error": f"Failed to get model info: {str(e)}"}


def print_separator():
    """Print a separator line for better readability."""
    print("\n" + "=" * 50 + "\n")


def get_file_location() -> str:
    """Prompt user for CSV file location and validate it exists."""
    while True:
        file_location = input("Please enter the path to your CSV file: ").strip()
        if Path(file_location).is_file():
            return file_location
        print("File not found. Please try again.")


def read_csv_file(file_location: str) -> List[Dict]:
    """Read CSV file and return list of dictionaries."""
    try:
        with open(file_location, 'r', encoding='utf-8') as file:
            reader = csv.DictReader(file)
            return list(reader)
    except Exception as e:
        print(f"Error reading CSV file: {e}")
        sys.exit(1)


def truncate_text(text: str, max_words: int) -> str:
    """Truncate text to specified number of words."""
    words = text.split()
    if len(words) <= max_words:
        return text
    return ' '.join(words[:max_words]) + "..."


def generate_short_prompt(row: Dict, index: int, total_rows: int) -> str:
    """Generate a short question based on row content and its position in the narrative."""
    content = ' '.join(str(value) for value in row.values() if value)

    # Describe where this section falls in the book so the model stays spoiler-free
    if index == 1:
        position = "the beginning of"
    elif index == total_rows:
        position = "the end of"
    else:
        position = "the middle of"
    narrative_position = f"This content is from {position} the story (section {index} of {total_rows})."

    base_prompt = """As an expert on Mary Shelley's Frankenstein, and referring ONLY to the provided source text,
generate a single, focused question that a high school student might ask (less than 30 words).
The question should relate to concepts that can be answered using DIRECT QUOTES from the source text.

IMPORTANT: {narrative_position} Consider this narrative context when forming your question.
Your question should be appropriate for this point in the story and not reference events that haven't occurred yet.

Source text:
'''
{source_text}
'''

Based on this row content: """
    return base_prompt.format(source_text=SOURCE_TEXT, narrative_position=narrative_position) + content


def query_ollama(prompt: str, index: int = None, total_rows: int = None) -> Dict:
    """Query Ollama API with the generated prompt."""
    narrative_context = ""
    if index is not None and total_rows is not None:
        narrative_context = f"""The content being analyzed is from a sequential reading of the book,
specifically section {index} of {total_rows}. Only reference events and quotes that would be known
at this point in the narrative. Do not reference future events or quote from later sections."""

    system_context = f"""You are an expert on Mary Shelley's Frankenstein. You must ONLY use direct quotes
from the provided source text. Never fabricate or paraphrase quotes. If you cannot find a relevant
direct quote, say so explicitly. Always cite the location of quotes when possible.
{narrative_context}"""

    payload = {
        "model": OLLAMA_MODEL,
        "prompt": f"{system_context}\n\n{prompt}",
        "stream": False,
        # Ollama's /api/generate takes generation limits via "options";
        # "num_predict" caps the number of tokens in the reply.
        "options": {"num_predict": 50}
    }

    try:
        response = requests.post(f"{OLLAMA_URL}/api/generate", json=payload)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error querying Ollama: {e}")
        return {"error": str(e)}


def save_to_json(data: Dict, filename: str):
    """Append data to a JSON file."""
    try:
        # Read existing data
        if Path(filename).exists():
            with open(filename, 'r', encoding='utf-8') as f:
                try:
                    existing_data = json.load(f)
                except json.JSONDecodeError:
                    existing_data = []
        else:
            existing_data = []

        # Append new data
        existing_data.append(data)

        # Write back to file
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(existing_data, f, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"Error saving to {filename}: {e}")


def get_rejection_reason() -> str:
    """Get the reason for rejection from user."""
    print("\nWhy was this response rejected? (Press Enter to skip)")
    return input("Reason: ").strip()


def get_feedback_reason(is_accepted: bool) -> str:
    """Get the reason for acceptance or rejection from user."""
    if is_accepted:
        print("\nWhy did you accept this response? (Press Enter to skip)")
    else:
        print("\nWhy did you reject this response? (Press Enter to skip)")
    return input("Reason: ").strip()


def display_row_and_get_feedback(row: Dict, index: int, total: int, model_info: Dict) -> Dict:
    """Display CSV row, prompt, and response, and get user feedback."""
    print_separator()
    print(f"Processing Row {index}/{total}")
    print(f"Using model: {OLLAMA_MODEL}")
    print(f"Narrative Position: Section {index} of {total}")

    print("\nROW CONTENT:")
    print("-" * 20)
    for key, value in row.items():
        print(f"{key}: {value}")

    # Generate and display prompt
    initial_prompt = generate_short_prompt(row, index, total)
    prompt_response = query_ollama(initial_prompt, index, total)
    question = truncate_text(prompt_response.get("response", ""), 50)

    print("\nGENERATED QUESTION:")
    print("-" * 20)
    print(question)

    # Generate and display short response
    response_prompt = f"""Provide a brief answer to this question using ONLY direct quotes from the source text.
You are an expert teacher speaking to high school students. Include the quote location if possible.

IMPORTANT: This question is about section {index} of {total} in the narrative. Only reference events
and quotes that would be known at this point in the story. Do not spoil future events.

Source text:
'''
{SOURCE_TEXT}
'''

Question: {question}"""
    answer_response = query_ollama(response_prompt, index, total)
    answer = truncate_text(answer_response.get("response", ""), 100)

    print("\nGENERATED RESPONSE:")
    print("-" * 20)
    print(answer)

    # Ask the expert reviewer to accept or reject this prompt/response pair
    while True:
        feedback = input("\nAccept this prompt and response? (y/n/q to quit): ").lower()
        if feedback in ['y', 'n', 'q']:
            break
        print("Please enter 'y' for yes, 'n' for no, or 'q' to quit.")

    if feedback == 'q':
        raise KeyboardInterrupt

    result = {
        "timestamp": datetime.now().isoformat(),
        "row_data": row,
        "narrative_position": f"Section {index} of {total}",
        "question": question,
        "response": answer,
        "accepted": feedback == 'y',
        "model": {
            "name": OLLAMA_MODEL,
            "details": model_info
        }
    }

    # If rejected, get reason
    if not result["accepted"]:
        reason = get_rejection_reason()
        result["rejection_reason"] = reason

    return result


def process_csv_with_ollama(file_location: str):
    """Process CSV file and generate prompts/responses."""
    rows = read_csv_file(file_location)
    total_rows = len(rows)
    accepted_count = 0
    rejected_count = 0

    print(f"\nProcessing {total_rows} rows...")

    # Get model information once at the start
    model_info = get_model_info()
    print(f"Using model: {OLLAMA_MODEL}")
    if "error" in model_info:
        print(f"Warning: {model_info['error']}")

    try:
        for index, row in enumerate(rows, 1):
            result = display_row_and_get_feedback(row, index, total_rows, model_info)
            if result["accepted"]:
                save_to_json(result, "response.json")
                accepted_count += 1
            else:
                save_to_json(result, "bad_response.json")
                rejected_count += 1
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
    finally:
        # Display summary
        processed_rows = accepted_count + rejected_count
        if processed_rows > 0:
            print("\nProcessing Summary:")
            print(f"Total rows processed: {processed_rows}/{total_rows}")
            print(f"Accepted responses: {accepted_count}")
            print(f"Rejected responses: {rejected_count}")
            acceptance_rate = (accepted_count / processed_rows) * 100
            print(f"Acceptance rate: {acceptance_rate:.1f}%")


def main():
    print("Frankenstein Text Analyzer")
    print("=========================")
    print(f"Using Ollama Model: {OLLAMA_MODEL}")

    try:
        # Get source text first
        global SOURCE_TEXT
        SOURCE_TEXT = get_source_text()

        # Get CSV file location
        file_location = get_file_location()

        # Process the CSV file
        process_csv_with_ollama(file_location)
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
        sys.exit(1)
    except Exception as e:
        print(f"\nAn error occurred: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
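Each reviewed pair is appended as one object to response.json (accepted) or bad_response.json (rejected), following the result dictionary built in display_row_and_get_feedback. A rejected entry would look roughly like this; the keys come from the code, the values here are purely illustrative:

[
  {
    "timestamp": "2024-10-23T14:20:11.123456",
    "row_data": {"section": "2", "content": "Letter II: ..."},
    "narrative_position": "Section 2 of 24",
    "question": "Why does Walton feel the absence of a friend on his voyage?",
    "response": "\"I have no friend, Margaret...\" (Letter II)",
    "accepted": false,
    "model": {"name": "mistral:latest", "details": {"...": "output of /api/show"}},
    "rejection_reason": "Quote location too vague"
  }
]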