Last active
November 18, 2024 16:41
-
-
Save milkey-mouse/67da93957d21515d8e3276e110867f14 to your computer and use it in GitHub Desktop.
Classify lines of a text file according to a category with GPT-4o(-mini)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from base64 import b64encode | |
from hashlib import blake2b | |
from math import exp, log | |
from typing import Iterable, TextIO | |
import asyncio | |
import json | |
import os | |
import random | |
import sqlite3 | |
import struct | |
import sys | |
# Upper bound on API requests started per minute; _main sleeps
# 60 / MAX_REQUESTS_PER_MINUTE seconds after launching each request.
MAX_REQUESTS_PER_MINUTE = 4_950
def slugify(s: str) -> bytes:
    """Turn *s* into a byte string that is safe to use as a file name.

    Alphanumeric characters are kept and every other character is replaced
    by "-". If anything was replaced, the replaced characters are appended
    after a "+" in base64 form, so inputs that differ only in their
    non-alphanumeric characters still produce different slugs.

    The base64 suffix uses "_" instead of "/": the standard alphabet can
    emit "/" (for example for "???"), which would be read as a path
    separator once the slug is joined into a file path. "+" is left
    untouched so slugs that never contained "/" are unchanged.

    Args:
        s: Arbitrary text.

    Returns:
        The UTF-8 encoded slug; *s* itself (encoded) when it is purely
        alphanumeric or empty.
    """
    special = "".join(c for c in s if not c.isalnum())
    if not special:
        return s.encode("utf-8")
    slug = "".join(c if c.isalnum() else "-" for c in s)
    suffix = b64encode(special.encode("utf-8"), altchars=b"+_")
    return slug.encode("utf-8") + b"+" + suffix
def logsumexp(logprobs: Iterable[float]) -> float:
    """Return log(sum(exp(lp) for lp in logprobs)) without overflow.

    The maximum is factored out before exponentiating so that large
    magnitudes do not overflow or underflow to zero.

    Args:
        logprobs: Log-probabilities (any iterable; it is consumed once).

    Returns:
        The log of the summed probabilities; -inf for an empty iterable or
        when every term is -inf.
    """
    values = tuple(logprobs)
    if not values:
        return float("-inf")
    peak = max(values)
    if abs(peak) == float("inf"):
        # Subtracting an infinite maximum would compute inf - inf = nan.
        # The limit is the maximum itself: -inf when every term is -inf,
        # +inf when some term is +inf.
        return peak
    return peak + log(sum(exp(lp - peak) for lp in values))
def print_usage(file: TextIO = sys.stdout): | |
print("usage: classify <category> <model>", file=file) | |
print("reads items from stdin, one per line", file=file) | |
def _yes_probability(yes: float, no: float) -> float:
    """Return P(YES) = 1 / (1 + exp(no - yes)) from two log-probability masses.

    Raises:
        RuntimeError: if neither mass is greater than -inf, i.e. there is
            no YES/NO token to base a probability on.
    """
    if not max(yes, no) > float("-inf"):
        # explicit raise rather than assert so the check survives `python -O`
        raise RuntimeError("no YES/NO token among the returned top logprobs")
    try:
        # split calculation for better numerical precision
        if no > yes:
            # NO more likely than YES
            return 1 / (1 + exp(no - yes))
        # YES more likely than NO
        return 1 - 1 / (1 + exp(yes - no))
    except OverflowError:
        # math.exp raises for finite gaps above ~709; the probability has
        # fully saturated by then
        return 0.0 if no > yes else 1.0


async def _main():
    """Classify each distinct stdin line against a category and print P(YES).

    Reads the category (required) and model (optional, default
    "gpt-4o-mini") from ``sys.argv``. Every distinct input line is looked up
    in a per-category/model SQLite cache under ~/.cache/classify, keyed by a
    16-byte BLAKE2b hash of the line. Lines missing from the cache are sent
    to the OpenAI chat completions API, limited to a single YES/NO output
    token, and the token log-probabilities are converted to a probability
    that is stored in the cache. Finally one tab-separated
    "probability, line" row is printed per distinct line, in first-seen
    order.

    Exits with status 0 after printing usage for -h/--help, and with status
    1 when the argument count is wrong.
    """
    from contextlib import closing

    if "-h" in sys.argv[1:] or "--help" in sys.argv[1:]:
        print_usage()
        sys.exit(0)
    elif len(sys.argv) not in range(2, 4):
        print_usage(file=sys.stderr)
        sys.exit(1)

    category = sys.argv[1]
    model = sys.argv[2] if len(sys.argv) > 2 else "gpt-4o-mini"

    cache_dir = os.path.expanduser(b"~/.cache/classify")
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, slugify(f"{category}-{model}")) + b".db"

    # sqlite3's own context manager only deals with transactions and never
    # closes the connection; closing() releases it on every exit path.
    with closing(sqlite3.connect(cache_path, autocommit=True)) as db:
        db.executescript(
            """
            PRAGMA synchronous = OFF;
            PRAGMA journal_mode = WAL;
            CREATE TABLE IF NOT EXISTS results (
                -- hash cached classification keys for confidentiality
                -- (whole BLOB PRIMARY KEY instead of prefix is ~5x slower)
                prefix INTEGER PRIMARY KEY, -- first 8 bytes of hash
                suffix INTEGER NOT NULL, -- last 8 bytes of hash
                p REAL
            ) WITHOUT ROWID;
            """
        )

        missing = []
        results = {}
        for item in sys.stdin:
            item = item.rstrip("\n")
            if item in results:
                continue

            digest = blake2b(item.encode("utf-8"), digest_size=16).digest()
            prefix, suffix = struct.unpack("<qq", digest)
            if cached := db.execute(
                "SELECT suffix, p FROM results WHERE prefix = ?", (prefix,)
            ).fetchone():
                row_suffix, p = cached
                if row_suffix == suffix:  # handle prefix collisions
                    results[item] = p
                    continue

            results[item] = None  # set item in results to avoid duplicates
            missing.append((item, prefix, suffix))

        if missing:
            from openai import AsyncOpenAI
            from tqdm.asyncio import tqdm

            random.shuffle(missing)  # obfuscate original order of lines to OpenAI

            with tqdm(total=len(missing), desc="classifying") as pbar:
                keys_path = os.path.expanduser("~/.config/io.datasette.llm/keys.json")
                with open(keys_path) as f:
                    client = AsyncOpenAI(api_key=json.load(f)["openai-personal"])

                async def classify(item: str, prefix: int, suffix: int):
                    """Query the model for one item; record and cache P(YES)."""
                    response = await client.chat.completions.create(
                        messages=(
                            {
                                "role": "user",
                                "content": (
                                    f"<item>{item}</item>\n"
                                    f"<category>{category}</category>\n"
                                    f"Does <item> match <category>? Output only YES or NO."
                                ),
                            },
                        ),
                        model=model,
                        logit_bias={
                            # for GPT-4o(-mini) tokenizer:
                            "31958": 100,  # YES
                            "14695": 100,  # NO
                        },
                        logprobs=True,
                        max_completion_tokens=1,
                        temperature=0,
                        seed=0,
                        top_logprobs=20,
                    )

                    lps = response.choices[0].logprobs.content[0].top_logprobs
                    yes = logsumexp(
                        lp.logprob for lp in lps if "yes" in lp.token.lower()
                    )
                    no = logsumexp(lp.logprob for lp in lps if "no" in lp.token.lower())
                    p = _yes_probability(yes, no)

                    results[item] = p
                    # OR REPLACE: a different item that shares this 8-byte
                    # prefix may already own the row (it is skipped on read
                    # as a collision); a plain INSERT would then fail on the
                    # primary key and abort the whole run.
                    db.execute(
                        "INSERT OR REPLACE INTO results (prefix, suffix, p)"
                        " VALUES (?, ?, ?)",
                        (prefix, suffix, p),
                    )

                tasks = set()
                for item, prefix, suffix in missing:
                    tasks.add(asyncio.create_task(classify(item, prefix, suffix)))
                    # pace request starts to stay under the per-minute budget
                    await asyncio.sleep(60 / MAX_REQUESTS_PER_MINUTE)
                    # non-blocking poll: collect whatever has finished so far
                    done, tasks = await asyncio.wait(
                        tasks,
                        timeout=0,
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    for task in done:
                        await task  # re-raises any exception from the task
                        pbar.update()

                # drain the requests that are still in flight
                for task in asyncio.as_completed(tasks):
                    await task
                    pbar.update()

        for item, p in results.items():
            print(f"{p:.16f}\t{item}")

        db.execute("VACUUM")
def main():
    """Synchronous entry point: run the async CLI to completion."""
    cli = _main()
    asyncio.run(cli)
# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[build-system] | |
requires = ["setuptools"] | |
build-backend = "setuptools.build_meta" | |
[project] | |
name = "classify" | |
version = "0.1.0" | |
dependencies = [ | |
"openai>=1.54.4", | |
"tqdm>=4.67.0" | |
] | |
[project.scripts] | |
classify = "classify:main" |
pipx install git+https://gist.github.com/67da93957d21515d8e3276e110867f14
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
License: CC0-1.0