Created
April 29, 2025 06:48
-
-
Save ondrejsojka/1dbd3436e63fbee7e371cc201164cb94 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script | |
# requires-python = ">=3.12" | |
# dependencies = [ | |
# "aiofiles", | |
# "aiohttp", | |
# "tiktoken", | |
# "openai>=1.75.0", # for openai.AsyncOpenAI | |
# "pathspec", | |
# ] | |
# /// | |
import asyncio | |
import aiofiles | |
import argparse | |
import os | |
import re | |
import sys | |
from pathlib import Path | |
import time | |
import tiktoken | |
import pathspec | |
from typing import Optional, Tuple, List, Dict | |
import concurrent.futures | |
import openai | |
# CONFIGURATION CONSTANTS | |
DEFAULT_MODEL = "gpt-4.1-nano" | |
MAX_CONCURRENT_WORKERS = 80 # stay well under RPM | |
MAX_RETRIES = 5 | |
INITIAL_BACKOFF = 1.0 | |
MAX_BACKOFF = 30 | |
USE_ANTHROPIC_COUNT = False # Opt-in | |
PROGRESS_REPORT_INTERVAL = 2.0 # seconds | |
USER_CONFIRM_TOKEN_THRESHOLD = 1_000_000 | |
tokenizer = tiktoken.get_encoding("o200k_base") | |
def count_tokens(text: str, tokenizer) -> int: | |
return len(tokenizer.encode(text)) | |
# ------- FILE UTIL & GITIGNORE ------- | |
def load_gitignore(root_dir: Path): | |
gitignore = root_dir / '.gitignore' | |
try: | |
with gitignore.open('r') as file: | |
spec = pathspec.PathSpec.from_lines('gitwildmatch', file) | |
except IOError: | |
print(f"Warning: Unable to read .gitignore in {root_dir}, ignoring .gitignore rules.") | |
spec = pathspec.PathSpec.from_lines('gitwildmatch', []) | |
return spec | |
def file_should_be_considered(p: Path, suffixes, prefixes, ignore_spec) -> bool: | |
fname = p.name | |
# .gitignored? | |
if ignore_spec.match_file(str(p.relative_to(p.anchor))): | |
return False | |
if suffixes and not any(fname.endswith(s) for s in suffixes): | |
return False | |
if prefixes and not any(fname.startswith(s) for s in prefixes): | |
return False | |
if p.name.startswith('.'): | |
return False | |
return p.is_file() | |
# ------- ASYNC TOKEN COUNT ------- | |
async def compute_total_tokens(files: List[Path], tokenizer): | |
async def read_and_count(p: Path): | |
try: | |
async with aiofiles.open(p, 'r') as f: | |
content = await f.read() | |
return count_tokens(content, tokenizer) | |
except Exception: | |
return 0 | |
tasks = [read_and_count(p) for p in files] | |
tokens = await asyncio.gather(*tasks) | |
return sum(tokens) | |
# ------- PRODUCER ------- | |
async def producer(file_queue: asyncio.Queue, root_dir: Path, suffixes, prefixes, ignore_spec): | |
for dirpath, dirnames, filenames in os.walk(root_dir): | |
dirnames[:] = [d for d in dirnames if not d.startswith('.git')] | |
for file in filenames: | |
path = Path(dirpath) / file | |
if file_should_be_considered(path, suffixes, prefixes, ignore_spec): | |
try: | |
async with aiofiles.open(path, 'r') as f: | |
content = await f.read() | |
await file_queue.put((str(path.relative_to(root_dir)), content)) | |
except Exception as e: | |
print(f"Warning: failed to read {path}: {e}") | |
# ------- INCLUSION CHECKER (WORKER) ------- | |
async def evaluate_inclusion(content: str, inclusion_prompt: str, semaphore, client, model, retries=MAX_RETRIES): | |
backoff = INITIAL_BACKOFF | |
for k in range(retries): | |
async with semaphore: | |
try: | |
resp = await client.chat.completions.create( | |
model=model, | |
messages=[ | |
{"role": "system", "content": "You are a classifier and decide whether a given file should be included or excluded based on the following instructions. Respond ONLY with y or n. y means include, n means exclude."}, | |
{"role": "system", "content": inclusion_prompt}, | |
{"role": "user", "content": content} | |
], | |
max_tokens=1, | |
temperature=0.1, | |
timeout=30, | |
) | |
result = resp.choices[0].message.content.lower() | |
if result.strip() == 'y' or result.strip() == ' y': | |
return True, 1.0 # score=1.0 (for extensibility) | |
elif result.strip() == 'n' or result.strip() == ' n': | |
return False, 0.0 | |
else: | |
print(f"Unexpected OpenAI classifier response: {result}") | |
return True, 0.5 | |
except openai.RateLimitError as e: | |
print(f"Hit OpenAI rate limit: {e}. Backing off {backoff:.1f}s") | |
await asyncio.sleep(backoff) | |
backoff = min(backoff * 2, MAX_BACKOFF) | |
except Exception as e: | |
print(f"OpenAI inclusion check error: {e}. Backing off {backoff:.1f}s") | |
await asyncio.sleep(backoff) | |
backoff = min(backoff * 2, MAX_BACKOFF) | |
return True, 0.0 # Default to include so worst-case is over-inclusion, not loss | |
async def consumer_worker(file_queue: asyncio.Queue, output_queue: asyncio.Queue, inclusion_prompt, semaphore, client, model, progress_state: dict): | |
while True: | |
try: | |
file_path, content = await file_queue.get() | |
except Exception: | |
break # Queue closed | |
try: | |
included, score = await evaluate_inclusion(content, inclusion_prompt, semaphore, client, model) | |
progress_state['checked'] += 1 | |
if included: | |
await output_queue.put((file_path, content, score)) | |
finally: | |
file_queue.task_done() | |
# ------- PROGRESS REPORTER ------- | |
async def progress_reporter(total_files, progress_state, file_queue, output_queue): | |
while not progress_state.get('done', False): | |
in_queue = file_queue.qsize() | |
out_queue = output_queue.qsize() | |
checked = progress_state.get('checked', 0) | |
print(f"Progress: {checked}/{total_files} checked | input queue: {in_queue} | output queue: {out_queue}") | |
await asyncio.sleep(PROGRESS_REPORT_INTERVAL) | |
# ------- WRITER: Collate and write outputs in order ------- | |
async def writer(output_queue: asyncio.Queue, output_file_path: Path, expected_count, root_dir: Path, tokenizer): | |
records = [] | |
for _ in range(expected_count): | |
file_path, content, score = await output_queue.get() | |
records.append((file_path, content, score)) | |
output_queue.task_done() | |
# sort and write | |
records.sort(key=lambda r: r[0]) | |
tmp_out = output_file_path.with_suffix(output_file_path.suffix + '.tmp') | |
total_output_tokens = 0 | |
async with aiofiles.open(tmp_out, 'w') as out: | |
for file_path, content, score in records: | |
header = f'---\n{file_path}\n' | |
await out.write(header) | |
await out.write(content) | |
await out.write('\n') | |
total_output_tokens += count_tokens(header + content + '\n', tokenizer) # Count tokens for header, content, and newline | |
tmp_out.rename(output_file_path) | |
print(f"Wrote output to {output_file_path}") | |
print(f"Total tokens in output file: {total_output_tokens}") | |
# ------- MAIN -------- | |
async def main(): | |
parser = argparse.ArgumentParser(description="Concatenate files in a directory (async filter via GPT), with robust throttling and interactive token awareness.") | |
parser.add_argument('root_dir', type=str, help='Root directory to search') | |
parser.add_argument('output_file', type=str, help='Output file path') | |
parser.add_argument('--filetype', nargs='+', default=[], help='File extensions (e.g. .py .js)') | |
parser.add_argument('--startswith', nargs='+', default=[], help='Filter: file name prefixes') | |
parser.add_argument('--inclusion_prompt', type=str, required=True, help='Prompt for GPT-4.1-nano to decide inclusion') | |
parser.add_argument('--max_workers', type=int, default=MAX_CONCURRENT_WORKERS, help='Max concurrent workers (default=80)') | |
parser.add_argument('--model', type=str, default=DEFAULT_MODEL, help='OpenAI model (default=gpt-4.1-nano)') | |
args = parser.parse_args() | |
model = args.model | |
tokenizer = tiktoken.get_encoding("o200k_base") | |
root_dir = Path(args.root_dir) | |
suffixes = args.filetype | |
prefixes = args.startswith | |
output_file_path = Path(args.output_file) | |
inclusion_prompt = args.inclusion_prompt | |
max_workers = args.max_workers | |
print("Loading .gitignore...") | |
ignore_spec = load_gitignore(root_dir) | |
print("Scanning and filtering files...") | |
all_files = [] | |
for dirpath, dirnames, filenames in os.walk(root_dir): | |
dirnames[:] = [d for d in dirnames if not d.startswith('.git')] | |
for file in filenames: | |
path = Path(dirpath) / file | |
if file_should_be_considered(path, suffixes, prefixes, ignore_spec): | |
all_files.append(path) | |
print(f"Candidate files: {len(all_files)}") | |
print("Precomputing total token count...") | |
token_total = await compute_total_tokens(all_files, tokenizer) | |
print(f"Total tokens in candidate files: {token_total}") | |
if token_total > USER_CONFIRM_TOKEN_THRESHOLD: | |
yn = input(f"WARNING: {token_total} tokens > {USER_CONFIRM_TOKEN_THRESHOLD} threshold. Continue? [y/N]: ") | |
if yn.lower() not in ("y", "yes"): | |
print("Aborted.") | |
return | |
file_queue = asyncio.Queue() | |
output_queue = asyncio.Queue() | |
progress_state = {'checked': 0, 'done': False} | |
print("Queueing file contents...") | |
await producer(file_queue, root_dir, suffixes, prefixes, ignore_spec) | |
total_files = file_queue.qsize() | |
print("Launching classification workers...") | |
openai_client = openai.AsyncOpenAI() | |
semaphore = asyncio.Semaphore(max_workers) | |
workers = [asyncio.create_task(consumer_worker( | |
file_queue, | |
output_queue, | |
inclusion_prompt, | |
semaphore, | |
openai_client, | |
model, | |
progress_state, | |
)) for _ in range(max_workers)] | |
progress_task = asyncio.create_task(progress_reporter(total_files, progress_state, file_queue, output_queue)) | |
# Let file_queue drain | |
await file_queue.join() | |
progress_state['done'] = True | |
await asyncio.sleep(0.2) | |
expected = output_queue.qsize() | |
await writer(output_queue, output_file_path, expected, root_dir, tokenizer) | |
for w in workers: | |
w.cancel() | |
await asyncio.sleep(0.1) | |
progress_task.cancel() | |
if __name__ == "__main__": | |
asyncio.run(main()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment