Script for QA-style dataset generation from custom data:
#!/usr/bin/env python3
"""
Refactored Q&A Dataset Generation Script
========================================
Features:
- Separate configuration for generator vs. judge (API keys, endpoints, and models).
- Environment-variable and CLI-driven configuration.
- Consistent use of pathlib for file paths.
- Modular logging with debug mode.
- Per-file ".inprogress" checkpoint + cumulative checkpoint for resume.
- tqdm progress bars.
- Consolidated prompt templates for multi-language support.
- Exponential backoff on rate limits.
- Incremental export to CSV/JSON/XLSX to minimize data loss on crash.

Version: 3.1
Author: [CultriX]
"""
import os
import sys
import json
import time
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime

import pandas as pd
from tqdm import tqdm
from openai import OpenAI
# -----------------------------------------------------------------------------
# 1. CONFIGURATION DATA CLASS & CLI PARSING
# -----------------------------------------------------------------------------
@dataclass
class Config:
    # Generator API settings
    gen_api_key: str = field(default_factory=lambda: os.getenv("OPENAI_GEN_API_KEY", ""))
    gen_base_url: str = field(default_factory=lambda: os.getenv("OPENAI_GEN_API_BASE", "http://localhost:1234/v1"))
    gen_model_name: str = "google/gemma-3-4b"

    # Judge API settings
    judge_api_key: Optional[str] = field(default=None)
    judge_base_url: Optional[str] = field(default=None)
    judge_model_name: Optional[str] = field(default=None)

    # Generation parameters
    question_max_tokens: int = 100
    answer_max_tokens: int = 350
    request_delay: float = 1.0  # seconds between requests
    max_retries: int = 3

    # File paths (input/output directories)
    input_directory: Path = Path("rag-input")
    output_directory: Path = Path("rag-output")
    output_filename: str = "qa_dataset"

    # Validation options
    resume_from_checkpoint: bool = True
    validate_qa_quality: bool = True
    llm_judge_quality: bool = True
    min_question_length: int = 10
    min_answer_length: int = 20

    # Language: 'en' or 'nl'
    language: str = "en"

    # Chunking parameters
    chunk_by_paragraphs: bool = True
    min_chunk_length_chars: int = 150

    # Logging and debug
    debug: bool = False

    def __post_init__(self):
        # If judge settings are not provided, default to generator settings
        if not self.judge_api_key:
            self.judge_api_key = self.gen_api_key
        if not self.judge_base_url:
            self.judge_base_url = self.gen_base_url
        if not self.judge_model_name:
            self.judge_model_name = self.gen_model_name
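# Example of the judge fallback above (illustrative values, not from the gist):
#   cfg = Config(gen_api_key="sk-local", gen_model_name="google/gemma-3-4b")
#   cfg.judge_api_key    -> "sk-local"
#   cfg.judge_model_name -> "google/gemma-3-4b"
# Only judge fields that are explicitly provided override the generator settings.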
def parse_args() -> Config:
    parser = argparse.ArgumentParser(
        description="Generate a Q&A dataset from a directory of text files."
    )
    # File/dir arguments
    parser.add_argument(
        "--input-dir", "-i", type=Path, default=Path("rag-input"),
        help="Directory containing .txt files to process."
    )
    parser.add_argument(
        "--output-dir", "-o", type=Path, default=Path("rag-output"),
        help="Directory where output CSV/JSON/XLSX will be written."
    )
    parser.add_argument(
        "--output-filename", "-n", type=str, default="qa_dataset",
        help="Base filename (without extension) for all exports."
    )
    # Generator API/Model
    parser.add_argument(
        "--gen-model", type=str, default="google/gemma-3-4b",
        help="Name of the LLM model to use for QA generation."
    )
    parser.add_argument(
        "--gen-api-base", type=str,
        default=os.getenv("OPENAI_GEN_API_BASE", "http://localhost:1234/v1"),
        help="Base URL for the generator's OpenAI-compatible endpoint."
    )
    parser.add_argument(
        "--gen-api-key", type=str, default=os.getenv("OPENAI_GEN_API_KEY", ""),
        help="OpenAI API key for the generator (or set env var OPENAI_GEN_API_KEY)."
    )
    # Judge API/Model
    parser.add_argument(
        "--judge-model", type=str, default=None,
        help="Name of the LLM model to use for QA judgment (defaults to generator model if not set)."
    )
    parser.add_argument(
        "--judge-api-base", type=str, default=None,
        help="Base URL for the judge's OpenAI-compatible endpoint (defaults to generator endpoint)."
    )
    parser.add_argument(
        "--judge-api-key", type=str, default=None,
        help="OpenAI API key for the judge (defaults to generator key if not set)."
    )
    # Chunking / Validation
    parser.add_argument(
        "--min-chunk-len", type=int, default=150,
        help="Minimum number of characters required to treat a paragraph as a chunk."
    )
    parser.add_argument(
        "--question-chars", type=int, default=10,
        help="Minimum number of characters in a generated question for rule-based validation."
    )
    parser.add_argument(
        "--answer-chars", type=int, default=20,
        help="Minimum number of characters in a generated answer for rule-based validation."
    )
    parser.add_argument(
        "--no-validate", dest="validate_qa_quality", action="store_false",
        help="Disable rule-based Q&A validation."
    )
    parser.add_argument(
        "--no-judge", dest="llm_judge_quality", action="store_false",
        help="Disable LLM-based Q&A quality judgment."
    )
    # Language
    parser.add_argument(
        "--lang", "-l", type=str, choices=["en", "nl"], default="en",
        help="Language for prompts (en or nl)."
    )
    # Resume and logging
    parser.add_argument(
        "--no-resume", dest="resume_from_checkpoint", action="store_false",
        help="Do not resume from a previous checkpoint; always start fresh."
    )
    parser.add_argument(
        "--debug", "-d", action="store_true", help="Enable debug logging."
    )
    args = parser.parse_args()

    cfg = Config(
        gen_api_key=args.gen_api_key,
        gen_base_url=args.gen_api_base,
        gen_model_name=args.gen_model,
        judge_api_key=args.judge_api_key,
        judge_base_url=args.judge_api_base,
        judge_model_name=args.judge_model,
        input_directory=args.input_dir,
        output_directory=args.output_dir,
        output_filename=args.output_filename,
        min_question_length=args.question_chars,
        min_answer_length=args.answer_chars,
        resume_from_checkpoint=args.resume_from_checkpoint,
        validate_qa_quality=args.validate_qa_quality,
        llm_judge_quality=args.llm_judge_quality,
        language=args.lang,
        min_chunk_length_chars=args.min_chunk_len,
        debug=args.debug
    )
    return cfg
# -----------------------------------------------------------------------------
# 2. LOGGING SETUP
# -----------------------------------------------------------------------------
def setup_logger(output_dir: Path, debug: bool) -> logging.Logger:
    log_dir = output_dir / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"qa_generation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    log_level = logging.DEBUG if debug else logging.INFO
    handlers = [
        logging.FileHandler(log_file, encoding="utf-8"),
        logging.StreamHandler(sys.stdout)
    ]
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=handlers
    )
    logger = logging.getLogger("QAGenerator")
    logger.info(f"Logger initialized at level: {'DEBUG' if debug else 'INFO'}")
    return logger
# -----------------------------------------------------------------------------
# 3. PROMPT TEMPLATES
# -----------------------------------------------------------------------------
PROMPT_TEMPLATES = {
    "en": {
        "question": (
            "You are a Professor writing an exam. Using ONLY the provided context below, "
            "formulate a single, clear, and specific question that captures an important "
            "fact or insight from THIS context. The question must be answerable solely from "
            "THIS context, ensuring the answer is entirely contained within this context.\n\n"
            "Context: \"{context}\"\n\n"
            "Generate a question similar to these examples:\n"
            "- \"Who was Aristophanes based on this text?\"\n"
            "- \"What are latifundia as described here?\"\n"
            "- \"What is ostracism according to this passage?\"\n\n"
            "Important:\n"
            "- Base the question STRICTLY on the provided context. Do not use external knowledge.\n"
            "- The question must be answerable solely from THIS context.\n"
            "- End with a single question mark.\n"
            "- Keep it concise but clear.\n"
            "- Do not output any text before the question itself.\n"
            "- Do not output any text after the question mark.\n\n"
            "Question:"
        ),
        "answer": (
            "Given ONLY the following context, provide a detailed, complete answer to the question. "
            "Do not add any external knowledge or make assumptions beyond the text.\n\n"
            "Context: \"{context}\"\n\n"
            "Question: \"{question}\"\n\n"
            "Instructions:\n"
            "- Answer the question directly and completely using ONLY information from the provided context.\n"
            "- Be specific and factual based on THIS context.\n"
            "- Avoid phrases like \"According to the text\", \"The context states\", etc. Just provide the answer.\n"
            "- Provide the answer directly without any preamble.\n\n"
            "Answer:"
        ),
        "judge": (
            "\nYou are an expert quality assurance reviewer for a Q&A dataset. "
            "Your task is to evaluate a given Question and its Answer based STRICTLY on the provided Context.\n\n"
            "Context: \"{context}\"\n\n"
            "Question: \"{question}\"\n\n"
            "Answer: \"{answer}\"\n\n"
            "Evaluate the following criteria:\n"
            "1. **Relevance**: Is the Question directly relevant to the Context? (Yes/No)\n"
            "2. **Answerability**: Is the Question fully answerable *solely* from the provided Context? (Yes/No)\n"
            "3. **Accuracy**: Is the Answer factually correct based on the provided Context? (Yes/No)\n"
            "4. **Completeness**: Does the Answer fully address the Question using information from the Context, "
            "without adding external knowledge or making assumptions? (Yes/No)\n"
            "5. **Conciseness**: Is the Answer free from unnecessary preamble, conversational filler, or irrelevant information? (Yes/No)\n\n"
            "Based on these evaluations, state 'GOOD' if ALL criteria are met (Yes for all). "
            "Otherwise, state 'BAD' and briefly explain which criteria failed.\n\n"
            "Example GOOD output:\nGOOD\n\n"
            "Example BAD output:\n"
            "BAD - Answerability: The question asks about X, but X is not mentioned in the context.\n"
            "BAD - Completeness: The answer only partially addresses the question; more details about Y were in the context but omitted.\n"
            "BAD - Accuracy: The answer contains information not found in the context.\n\n"
            "Judgment:"
        )
    },
    "nl": {
        "question": (
            "Je bent een Professor die een examen schrijft. Gebruik ALLEEN de onderstaande context om een "
            "enkele, duidelijke en specifieke vraag te formuleren die een belangrijk feit of inzicht uit DEZE context "
            "vastlegt. De vraag moet uitsluitend uit DEZE context beantwoord kunnen worden, zodat het antwoord "
            "volledig in deze context is opgenomen.\n\n"
            "Context: \"{context}\"\n\n"
            "Genereer een vraag die vergelijkbaar is met deze voorbeelden:\n"
            "- \"Wie was Aristophanes op basis van deze tekst?\"\n"
            "- \"Wat zijn latifundia zoals hier beschreven?\"\n"
            "- \"Wat is ostracisme volgens deze passage?\"\n\n"
            "Belangrijk:\n"
            "- Baseer de vraag STRIKT op de geleverde context. Gebruik geen externe kennis.\n"
            "- De vraag moet uitsluitend uit DEZE context beantwoord kunnen worden.\n"
            "- Eindig met één enkel vraagteken.\n"
            "- Houd het beknopt maar duidelijk.\n"
            "- Geef geen tekst vóór de vraag zelf.\n"
            "- Geef geen tekst na het vraagteken.\n\n"
            "Vraag:"
        ),
        "answer": (
            "Gegeven ALLEEN de volgende context, geef een gedetailleerd en volledig antwoord op de vraag. "
            "Voeg geen externe kennis toe en doe geen aannames buiten de tekst.\n\n"
            "Context: \"{context}\"\n\n"
            "Vraag: \"{question}\"\n\n"
            "Instructies:\n"
            "- Beantwoord de vraag direct en volledig met ALLEEN informatie uit de geleverde context.\n"
            "- Wees specifiek en feitelijk gebaseerd op DEZE context.\n"
            "- Vermijd zinnen zoals \"Volgens de tekst\", \"De context vermeldt\", enz. Geef gewoon het antwoord.\n"
            "- Geef het antwoord direct zonder inleiding.\n\n"
            "Antwoord:"
        ),
        "judge": (
            "\nU bent een deskundige kwaliteitscontroleur voor een Q&A-dataset. Uw taak is om een gegeven Vraag en Antwoord "
            "STRIKT te evalueren op basis van de geleverde Context.\n\n"
            "Context: \"{context}\"\n\n"
            "Vraag: \"{question}\"\n\n"
            "Antwoord: \"{answer}\"\n\n"
            "Evalueer de volgende criteria:\n"
            "1. **Relevantie**: Is de Vraag direct relevant voor de Context? (Ja/Nee)\n"
            "2. **Beantwoordbaarheid**: Is de Vraag volledig beantwoordbaar *uitsluitend* vanuit de geleverde Context? (Ja/Nee)\n"
            "3. **Nauwkeurigheid**: Is het Antwoord feitelijk correct op basis van de geleverde Context? (Ja/Nee)\n"
            "4. **Volledigheid**: Behandelt het Antwoord de Vraag volledig met informatie uit de Context, "
            "zonder externe kennis toe te voegen of aannames te doen? (Ja/Nee)\n"
            "5. **Beknoptheid**: Is het Antwoord vrij van onnodige inleiding, conversatievulling of irrelevante informatie? (Ja/Nee)\n\n"
            "Op basis van deze evaluaties, vermeld 'GOED' als AAN ALLE criteria is voldaan (Ja bij alles). "
            "Anders, vermeld 'FOUT' en leg kort uit welke criteria niet voldeden.\n\n"
            "Voorbeeld GOED uitvoer:\nGOED\n\n"
            "Voorbeeld FOUT uitvoer:\n"
            "FOUT - Beantwoordbaarheid: De vraag gaat over X, maar X wordt niet genoemd in de context.\n"
            "FOUT - Volledigheid: Het antwoord behandelt de vraag slechts gedeeltelijk; meer details over Y stonden in de context maar zijn weggelaten.\n"
            "FOUT - Nauwkeurigheid: Het antwoord bevat informatie die niet in de context is gevonden.\n\n"
            "Oordeel:"
        )
    }
}
# -----------------------------------------------------------------------------
# 4. QAGenerator CLASS
# -----------------------------------------------------------------------------
class QAGenerator:
    def __init__(self, config: Config):
        self.config = config
        self.logger = setup_logger(config.output_directory, config.debug)

        # Initialize two separate OpenAI clients:
        self.gen_client = self._initialize_client(
            api_key=self.config.gen_api_key,
            base_url=self.config.gen_base_url,
            name="Generator"
        )
        self.judge_client = self._initialize_client(
            api_key=self.config.judge_api_key,
            base_url=self.config.judge_base_url,
            name="Judge"
        )

        # Checkpoint files
        self.ckpt_cumulative = config.output_directory / f"{config.output_filename}_checkpoint.json"
        self.processed_files: set = self._load_cumulative_checkpoint()

        # In-memory list for the current run
        self.qa_data_current_run: List[Dict[str, Any]] = []

        # Ensure output directory exists
        config.output_directory.mkdir(parents=True, exist_ok=True)

    def _initialize_client(self, api_key: str, base_url: str, name: str) -> OpenAI:
        if not api_key:
            self.logger.error(f"No API key provided for {name}. Set the appropriate environment variable or CLI flag.")
            sys.exit(1)
        try:
            client = OpenAI(
                api_key=api_key,
                base_url=base_url
            )
            self.logger.debug(f"{name} client initialized (model endpoint: {base_url})")
            return client
        except Exception as e:
            self.logger.error(f"Failed to initialize {name} client: {e}")
            sys.exit(1)
    def _load_cumulative_checkpoint(self) -> set:
        """Load the list of already-processed files from the cumulative checkpoint."""
        if (not self.config.resume_from_checkpoint) or (not self.ckpt_cumulative.exists()):
            return set()
        try:
            with open(self.ckpt_cumulative, "r", encoding="utf-8") as f:
                data = json.load(f)
            processed = set(data.get("processed_files", []))
            self.logger.info(f"Loaded checkpoint: {len(processed)} files already processed")
            return processed
        except Exception as e:
            self.logger.warning(f"Could not load checkpoint: {e}")
            return set()

    def _save_cumulative_checkpoint(self):
        """Save the set of processed files to disk."""
        try:
            data = {
                "processed_files": sorted(self.processed_files),
                "timestamp": datetime.now().isoformat(),
                "total_processed": len(self.processed_files),
            }
            with open(self.ckpt_cumulative, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            self.logger.error(f"Failed to save checkpoint: {e}")

    def _get_text_files(self) -> List[Path]:
        """Return the list of .txt files in input_directory."""
        pattern = self.config.input_directory.glob("*.txt")
        files = sorted([p for p in pattern if p.is_file()])
        if not files:
            raise FileNotFoundError(f"No .txt files found in {self.config.input_directory}")
        self.logger.info(f"Found {len(files)} text files to process")
        return files

    def _get_text_chunks(self, content: str) -> List[str]:
        """Split a file's content into chunks >= min_chunk_length_chars."""
        if not content.strip():
            return []
        if self.config.chunk_by_paragraphs:
            raw_chunks = content.split("\n\n")
        else:
            raw_chunks = [content]
        valid_chunks: List[str] = []
        for chunk in raw_chunks:
            chunk_stripped = chunk.strip()
            if len(chunk_stripped) >= self.config.min_chunk_length_chars:
                valid_chunks.append(chunk_stripped)
            else:
                self.logger.debug(
                    f"Skipping chunk (len {len(chunk_stripped)} < {self.config.min_chunk_length_chars}): "
                    f"'{chunk_stripped[:30]}...'"
                )
        return valid_chunks

    def _strip_surrounding_quotes(self, text: str) -> str:
        t = text.strip()
        if len(t) >= 2:
            if (t[0] in ('"', "'", "“", "‘")) and (t[-1] in ('"', "'", "”", "’")):
                return t[1:-1].strip()
        return t
    def _call_llm_with_backoff(
        self,
        client: OpenAI,
        model_name: str,
        prompt: str,
        max_tokens: int,
        temperature: float
    ) -> Optional[str]:
        """
        Call the LLM with up to max_retries attempts. If the exception message contains
        'rate limit', apply exponential backoff before retrying.
        """
        for attempt in range(1, self.config.max_retries + 1):
            try:
                response = client.completions.create(
                    model=model_name,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=temperature
                )
                return response.choices[0].text
            except Exception as e:
                msg = str(e).lower()
                if "rate limit" in msg:
                    wait_secs = self.config.request_delay * (2 ** (attempt - 1))
                    self.logger.warning(
                        f"Rate limit detected (attempt {attempt}/{self.config.max_retries}). "
                        f"Retrying in {wait_secs:.1f}s..."
                    )
                    time.sleep(wait_secs)
                    continue
                else:
                    self.logger.error(f"LLM call failed (non-rate-limit): {e}", exc_info=True)
                    return None
        self.logger.error("Exceeded max retries for LLM call.")
        return None
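    # Backoff schedule of the method above (illustrative numbers): with the defaults
    # request_delay=1.0 and max_retries=3, a persistent rate limit produces waits of
    # 1.0s, 2.0s and 4.0s (request_delay * 2**(attempt - 1)) before the call gives up
    # and returns None.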
    def _validate_qa_pair(
        self,
        question: str,
        answer: str,
        context: str
    ) -> bool:
        """Rule-based validation of a single Q&A pair."""
        if not self.config.validate_qa_quality:
            self.logger.debug("Rule-based validation is disabled.")
            return True
        q = question.strip()
        a = answer.strip()

        # Basic length checks
        if len(q) < self.config.min_question_length:
            self.logger.warning(
                f"Validation FAIL: Question too short ({len(q)}/{self.config.min_question_length})."
                f" Q: '{q}'"
            )
            return False
        if len(a) < self.config.min_answer_length:
            self.logger.warning(
                f"Validation FAIL: Answer too short ({len(a)}/{self.config.min_answer_length})."
                f" A: '{a}'"
            )
            return False

        # (The "must end with '?'" check was removed here.)

        # Language-specific error indicators
        if self.config.language == "en":
            indicators = [
                "i cannot", "i can't", "i don't know", "based on the context provided",
                "according to the text", "the context doesn't", "not mentioned in the context",
                "the provided context does not contain", "the context does not mention"
            ]
        else:  # 'nl'
            indicators = [
                "ik kan niet", "ik weet het niet", "op basis van de context",
                "volgens de tekst", "de context vermeldt niet", "niet vermeld in de context",
                "de gegeven context bevat geen", "de context noemt niet"
            ]
        for ind in indicators:
            if ind.lower() in a.lower():
                self.logger.warning(
                    f"Validation FAIL: Found error indicator '{ind}' in answer. A: '{a}'"
                )
                return False

        # Keyword overlap check (simple)
        if not context.strip():
            self.logger.warning("Validation FAIL: Empty context.")
            return False
        context_words = set(context.lower().split())
        answer_words = set(a.lower().split())
        overlap = len(context_words.intersection(answer_words))
        if overlap < 3:
            self.logger.warning(
                f"Validation FAIL: Insufficient keyword overlap ({overlap}/3). A: '{a[:30]}...' "
                f"Context: '{context[:30]}...'"
            )
            return False
        self.logger.debug("Rule-based validation PASS.")
        return True

    def _llm_qa_quality_check(
        self,
        context: str,
        question: str,
        answer: str
    ) -> bool:
        """Ask the judge LLM to evaluate an existing Q&A pair. Returns True if judged GOOD."""
        if not self.config.llm_judge_quality:
            self.logger.debug("LLM-based QA judgment is disabled.")
            return True
        template = PROMPT_TEMPLATES[self.config.language]["judge"]
        prompt = template.format(context=context, question=question, answer=answer)
        raw = self._call_llm_with_backoff(
            client=self.judge_client,
            model_name=self.config.judge_model_name,
            prompt=prompt,
            max_tokens=150,
            temperature=0.0
        )
        if raw is None:
            # Treat as a failure: better to filter out a questionable QA pair than keep a bad one.
            self.logger.error("LLM judgment returned None (treating as BAD).")
            return False
        verdict = raw.strip().upper()
        first_line = verdict.splitlines()[0] if verdict else ""
        self.logger.debug(f"LLM judgment: '{first_line}'")
        if verdict.startswith("GOOD"):
            return True
        else:
            self.logger.warning(f"LLM judged Q&A as BAD: {verdict}")
            return False
    def generate_question_and_answer(
        self,
        chunk: str
    ) -> Optional[Tuple[str, str]]:
        """
        Generate a single (question, answer) pair from `chunk`.
        Returns None if no attempt passes both the rule-based and LLM-based checks
        within max_retries attempts.
        """
        lang = self.config.language
        q_template = PROMPT_TEMPLATES[lang]["question"]
        a_template = PROMPT_TEMPLATES[lang]["answer"]
        for attempt in range(1, self.config.max_retries + 1):
            self.logger.debug(f"Generating Q&A (attempt {attempt}/{self.config.max_retries})")

            # 1) Produce a question using the generator client
            q_prompt = q_template.format(context=chunk)
            q_raw = self._call_llm_with_backoff(
                client=self.gen_client,
                model_name=self.config.gen_model_name,
                prompt=q_prompt,
                max_tokens=self.config.question_max_tokens,
                temperature=0.7
            )
            if not q_raw:
                self.logger.error("Failed to get a question from the generator LLM.")
                continue
            question = self._strip_surrounding_quotes(q_raw)
            time.sleep(self.config.request_delay)

            # 2) Produce an answer using the generator client
            a_prompt = a_template.format(context=chunk, question=question)
            a_raw = self._call_llm_with_backoff(
                client=self.gen_client,
                model_name=self.config.gen_model_name,
                prompt=a_prompt,
                max_tokens=self.config.answer_max_tokens,
                temperature=0.3
            )
            if not a_raw:
                self.logger.error("Failed to get an answer from the generator LLM.")
                continue
            answer = self._strip_surrounding_quotes(a_raw)

            # 3) Rule-based validation
            if not self._validate_qa_pair(question, answer, chunk):
                self.logger.info("Rule-based validation failed; retrying.")
                continue

            # 4) LLM-based quality check using the judge client
            if self._llm_qa_quality_check(chunk, question, answer):
                return question, answer
            else:
                self.logger.info("LLM-based judgment failed; retrying.")
        self.logger.error("Exceeded max retries. No valid Q&A for this chunk.")
        return None
    def _export_results_for_file(
        self,
        file_name: str,
        qa_list: List[Dict[str, Any]]
    ):
        """
        Immediately append this file's QA pairs to the cumulative CSV/JSON outputs.
        Also write a per-file ".inprogress" JSON so partial results are recoverable.
        """
        out_dir = self.config.output_directory
        base = self.config.output_filename

        # 1) Write per-file .inprogress JSON (overwritten each time for this file)
        inprog_path = out_dir / f"{file_name}.inprogress.json"
        try:
            with open(inprog_path, "w", encoding="utf-8") as f:
                json.dump(qa_list, f, indent=2, ensure_ascii=False)
        except Exception as e:
            self.logger.warning(f"Could not write in-progress file {inprog_path}: {e}")

        # 2) Append to CSV
        csv_path = out_dir / f"{base}.csv"
        df_new = pd.DataFrame(qa_list)
        if csv_path.exists():
            try:
                df_old = pd.read_csv(csv_path, encoding="utf-8")
                df_combined = pd.concat([df_old, df_new], ignore_index=True)
                df_combined.to_csv(csv_path, index=False, encoding="utf-8")
                self.logger.info(
                    f"Appended {len(df_new)} rows to existing CSV: {csv_path} "
                    f"(now {len(df_combined)} total)."
                )
            except Exception as e:
                self.logger.warning(
                    f"Failed to append to CSV {csv_path}. Overwriting. Error: {e}"
                )
                df_new.to_csv(csv_path, index=False, encoding="utf-8")
        else:
            df_new.to_csv(csv_path, index=False, encoding="utf-8")
            self.logger.info(f"Wrote new CSV with {len(df_new)} rows: {csv_path}")

        # 3) Append to JSON
        json_path = out_dir / f"{base}.json"
        final_data: List[Dict[str, Any]] = []
        if json_path.exists():
            try:
                with open(json_path, "r", encoding="utf-8") as f:
                    existing = json.load(f)
                if isinstance(existing, list):
                    final_data = existing
                else:
                    self.logger.warning("Existing JSON is not a list; reinitializing.")
                    final_data = []
            except Exception:
                self.logger.warning("Could not read existing JSON; reinitializing.")
                final_data = []
        final_data.extend(qa_list)
        try:
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(final_data, f, indent=2, ensure_ascii=False)
            self.logger.info(f"Wrote/Updated JSON: {json_path} ({len(final_data)} total).")
        except Exception as e:
            self.logger.error(f"Failed to write JSON {json_path}: {e}")

        # 4) The XLSX export is written once per run (at the end), not per file,
        #    to reduce overhead.

    def _export_full_excel(self):
        """
        After all files are processed, write an XLSX with all accumulated data.
        """
        base = self.config.output_filename
        xlsx_path = self.config.output_directory / f"{base}.xlsx"
        json_path = self.config.output_directory / f"{base}.json"
        if not json_path.exists():
            self.logger.info("No JSON file to convert to Excel.")
            return
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, list) or not data:
                self.logger.info("JSON is empty or not a list; skipping Excel export.")
                return
            df = pd.DataFrame(data)
            df.to_excel(xlsx_path, index=False, engine="openpyxl")
            self.logger.info(f"Exported full Excel: {xlsx_path} ({len(df)} rows).")
        except ImportError:
            self.logger.warning("openpyxl not installed; cannot write Excel.")
        except Exception as e:
            self.logger.error(f"Error writing Excel {xlsx_path}: {e}")
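    # Output layout produced by the two export methods above (illustrative, using the
    # default --output-filename "qa_dataset" and --output-dir "rag-output"):
    #   rag-output/qa_dataset.csv            cumulative CSV, appended to per file
    #   rag-output/qa_dataset.json           cumulative JSON list, rewritten per file
    #   rag-output/qa_dataset.xlsx           written once at the end of the run
    #   rag-output/<file>.inprogress.json    per-file partial results
    # Each record carries the keys: Context, Question, Answer, Source_File, Chunk_ID.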
    def process_files(self):
        """
        Main loop: iterate over each .txt file, skip files that are already processed,
        generate Q&A for each chunk, and write per-file outputs immediately.
        """
        try:
            all_files = self._get_text_files()
        except FileNotFoundError as e:
            self.logger.error(str(e))
            return
        to_process = [f for f in all_files if str(f) not in self.processed_files]
        self.logger.info(
            f"Processing {len(to_process)} new files; {len(self.processed_files)} already done."
        )
        for file_path in to_process:
            file_name = file_path.stem
            self.logger.info(f"=== Processing file: {file_name}.txt ===")
            per_file_list: List[Dict[str, Any]] = []
            try:
                text = file_path.read_text(encoding="utf-8")
            except Exception as e:
                self.logger.error(f"Could not read {file_path}: {e}")
                # Mark as processed so it doesn't block future runs
                self.processed_files.add(str(file_path))
                self._save_cumulative_checkpoint()
                continue
            if not text.strip():
                self.logger.warning(f"File {file_name}.txt is empty; skipping.")
                self.processed_files.add(str(file_path))
                self._save_cumulative_checkpoint()
                continue
            chunks = self._get_text_chunks(text)
            if not chunks:
                self.logger.warning(
                    f"No chunks >= {self.config.min_chunk_length_chars} chars in {file_name}.txt."
                )
                self.processed_files.add(str(file_path))
                self._save_cumulative_checkpoint()
                continue
            self.logger.info(f"File {file_name}.txt: {len(chunks)} chunks to process.")

            # Use tqdm to show progress through chunks
            for idx, chunk in enumerate(tqdm(chunks, desc=file_name, unit="chunk"), start=1):
                self.logger.debug(f"Processing chunk {idx}/{len(chunks)} of {file_name}.txt")
                qa = self.generate_question_and_answer(chunk)
                if qa:
                    question, answer = qa
                    entry = {
                        "Context": chunk,
                        "Question": question,
                        "Answer": answer,
                        "Source_File": file_path.name,
                        "Chunk_ID": idx,
                    }
                    per_file_list.append(entry)
                    self.qa_data_current_run.append(entry)
                else:
                    self.logger.error(
                        f"Failed to generate valid Q&A for chunk {idx} of {file_name}.txt"
                    )

            # After all chunks, export results for this file immediately
            if per_file_list:
                self._export_results_for_file(file_name, per_file_list)
                self.logger.info(
                    f"Finished {file_name}.txt → {len(per_file_list)} Q&A pairs written."
                )
            else:
                self.logger.warning(f"No Q&A pairs generated from {file_name}.txt")

            # Mark file as done & update cumulative checkpoint
            self.processed_files.add(str(file_path))
            self._save_cumulative_checkpoint()

        # All files processed; remove the checkpoint if none remain
        try:
            remaining = [
                f for f in all_files if str(f) not in self.processed_files
            ]
            if not remaining and self.ckpt_cumulative.exists():
                self.logger.info("All files processed; deleting cumulative checkpoint.")
                self.ckpt_cumulative.unlink()
        except Exception as e:
            self.logger.warning(f"Error checking remaining files: {e}")

    def export_summary(self):
        """
        Print a run summary to STDOUT and write the final Excel export.
        """
        total_new = len(self.qa_data_current_run)
        self.logger.info(f"Run complete: generated {total_new} total Q&A pairs in this run.")
        print("\n" + "=" * 50)
        print("EXPORT SUMMARY (This Run)")
        print("=" * 50)
        print(f"Q&A pairs generated in this run: {total_new}")
        if total_new > 0:
            sample = self.qa_data_current_run[:2]
            for entry in sample:
                print("-" * 50)
                print(f"Source File: {entry['Source_File']} | Chunk: {entry['Chunk_ID']}")
                print(f"Question: {entry['Question']}")
                print(f"Answer: {entry['Answer'][:100]}...")
        print(f"Output directory: {self.config.output_directory}")
        print("=" * 50 + "\n")

        # Write the full Excel now that all files are done
        self._export_full_excel()

    def run(self):
        start = time.time()
        self.logger.info("=== Starting Q&A generation ===")
        try:
            self.process_files()
            self.export_summary()
        except KeyboardInterrupt:
            self.logger.info("Interrupted by user; checkpoint saved.")
        except Exception:
            self.logger.error("Fatal error in run():", exc_info=True)
        finally:
            duration = time.time() - start
            self.logger.info(f"Total runtime: {duration:.2f}s")
# -----------------------------------------------------------------------------
# 5. MAIN ENTRY POINT
# -----------------------------------------------------------------------------
def main():
    cfg = parse_args()
    # Exit early if no generator API key was provided
    if not cfg.gen_api_key:
        print("Error: No generator API key provided. Either set OPENAI_GEN_API_KEY or pass --gen-api-key.", file=sys.stderr)
        sys.exit(1)
    # Create input/output directories if they don't exist
    cfg.input_directory.mkdir(parents=True, exist_ok=True)
    cfg.output_directory.mkdir(parents=True, exist_ok=True)

    generator = QAGenerator(cfg)
    generator.run()


if __name__ == "__main__":
    main()