Skip to content

Instantly share code, notes, and snippets.

@milkey-mouse
Last active November 18, 2024 16:41
Show Gist options
  • Save milkey-mouse/67da93957d21515d8e3276e110867f14 to your computer and use it in GitHub Desktop.
Save milkey-mouse/67da93957d21515d8e3276e110867f14 to your computer and use it in GitHub Desktop.
Classify lines of a text file according to a category with GPT-4o(-mini)
#!/usr/bin/env python3
from base64 import b64encode
from hashlib import blake2b
from math import exp, log
from typing import Iterable, TextIO
import asyncio
import json
import os
import random
import sqlite3
import struct
import sys
# Pacing for API calls: the main loop sleeps 60 / MAX_REQUESTS_PER_MINUTE
# seconds between starting classification requests. Presumably chosen just
# under an account limit of 5000 requests/minute — TODO confirm for your tier.
MAX_REQUESTS_PER_MINUTE = 4_950
def slugify(s: str) -> bytes:
    """Encode *s* as a single filesystem-safe filename component (bytes).

    Alphanumeric characters pass through unchanged. If *s* contains any
    non-alphanumeric characters, each one becomes ``-`` in the readable
    slug, and the replaced characters themselves (UTF-8, base64-encoded)
    are appended after a ``+``. Because the ``-`` positions mark exactly
    where those characters were, different inputs yield different names.

    The base64 suffix uses ``_`` in place of ``/``: the result is joined
    onto a directory path, and a ``/`` (e.g. ``slugify("is it cool?")``
    used to end in ``+ICA/``) would be taken as a path separator. ``+`` is
    kept as-is so every name that was already valid stays the same.
    """
    special = "".join(c for c in s if not c.isalnum())
    if not special:
        return s.encode("utf-8")
    slug = "".join(c if c.isalnum() else "-" for c in s)
    # altchars remaps only the two non-alphanumeric base64 symbols: '+' -> '+', '/' -> '_'
    encoded = b64encode(special.encode("utf-8"), altchars=b"+_")
    return slug.encode("utf-8") + b"+" + encoded
def logsumexp(logprobs: Iterable[float]) -> float:
    """Return ``log(sum(exp(lp) for lp in logprobs))`` computed stably.

    The largest value is factored out before exponentiating so large
    magnitudes do not overflow or underflow. An empty input returns
    ``-inf`` (the log of an empty sum). If the largest value is itself
    infinite it is returned directly; otherwise the subtraction would be
    ``inf - inf`` and the result NaN (e.g. when every input is ``-inf``).
    """
    values = tuple(logprobs)
    if not values:
        return float("-inf")
    max_lp = max(values)
    if max_lp in (float("inf"), float("-inf")):
        return max_lp
    return max_lp + log(sum(exp(lp - max_lp) for lp in values))
def print_usage(file: TextIO = sys.stdout):
print("usage: classify <category> <model>", file=file)
print("reads items from stdin, one per line", file=file)
def _parse_args() -> tuple[str, str]:
    """Return ``(category, model)`` parsed from ``sys.argv``.

    Prints usage and exits with status 0 for ``-h``/``--help``; prints usage
    to stderr and exits with status 1 unless one or two arguments are given.
    """
    args = sys.argv[1:]
    if "-h" in args or "--help" in args:
        print_usage()
        sys.exit(0)
    if len(args) not in (1, 2):
        print_usage(file=sys.stderr)
        sys.exit(1)
    category = args[0]
    model = args[1] if len(args) > 1 else "gpt-4o-mini"
    return category, model


def _open_cache(category: str, model: str) -> sqlite3.Connection:
    """Open (creating if needed) the result cache for this category/model pair."""
    cache_dir = os.path.expanduser(b"~/.cache/classify")
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, slugify(f"{category}-{model}")) + b".db"
    db = sqlite3.connect(cache_path, autocommit=True)
    db.executescript(
        """
        PRAGMA synchronous = OFF;
        PRAGMA journal_mode = WAL;
        CREATE TABLE IF NOT EXISTS results (
            -- hash cached classification keys for confidentiality
            -- (whole BLOB PRIMARY KEY instead of prefix is ~5x slower)
            prefix INTEGER PRIMARY KEY, -- first 8 bytes of hash
            suffix INTEGER NOT NULL, -- last 8 bytes of hash
            p REAL
        ) WITHOUT ROWID;
        """
    )
    return db


def _cache_key(item: str) -> tuple[int, int]:
    """Split the 16-byte BLAKE2b digest of *item* into two signed 64-bit ints."""
    digest = blake2b(item.encode("utf-8"), digest_size=16).digest()
    prefix, suffix = struct.unpack("<qq", digest)
    return prefix, suffix


def _load_api_key() -> str | None:
    """Return the "openai-personal" key saved by the ``llm`` CLI, if present.

    Returns None when the keys file or that entry is missing; the OpenAI
    client then uses its own default (the OPENAI_API_KEY environment variable).
    """
    keys_path = os.path.expanduser("~/.config/io.datasette.llm/keys.json")
    try:
        with open(keys_path, encoding="utf-8") as f:
            return json.load(f).get("openai-personal")
    except FileNotFoundError:
        return None


def _p_yes(yes: float, no: float) -> float:
    """Return P(YES) = 1 / (1 + exp(no - yes)) from the two log-probabilities.

    exp() is only ever applied to a non-positive argument, so an arbitrarily
    large gap between *yes* and *no* cannot raise OverflowError; an infinite
    gap (one side ``-inf``) gives exactly 0.0 or 1.0.
    """
    if no > yes:
        z = exp(yes - no)  # in [0, 1)
        return z / (1 + z)
    z = exp(no - yes)  # in [0, 1]
    return 1 / (1 + z)


async def _classify_missing(
    db: sqlite3.Connection,
    category: str,
    model: str,
    missing: list[tuple[str, int, int]],
    results: dict[str, float | None],
) -> None:
    """Classify each ``(item, prefix, suffix)`` in *missing* with the OpenAI API.

    Each probability is written into *results* and the cache table. Request
    starts are spaced 60 / MAX_REQUESTS_PER_MINUTE seconds apart. *missing*
    is shuffled in place.
    """
    from openai import AsyncOpenAI
    from tqdm.asyncio import tqdm

    random.shuffle(missing)  # obfuscate original order of lines to OpenAI
    client = AsyncOpenAI(api_key=_load_api_key())

    async def classify(item: str, prefix: int, suffix: int):
        response = await client.chat.completions.create(
            messages=(
                {
                    "role": "user",
                    "content": (
                        f"<item>{item}</item>\n"
                        f"<category>{category}</category>\n"
                        f"Does <item> match <category>? Output only YES or NO."
                    ),
                },
            ),
            model=model,
            logit_bias={
                # for GPT-4o(-mini) tokenizer:
                "31958": 100,  # YES
                "14695": 100,  # NO
            },
            logprobs=True,
            max_completion_tokens=1,
            temperature=0,
            seed=0,
            top_logprobs=20,
        )
        lps = response.choices[0].logprobs.content[0].top_logprobs
        # NOTE(review): substring matching also counts tokens like "not" or
        # "know" toward NO and "yesterday" toward YES — intentional? TODO confirm.
        yes = logsumexp(lp.logprob for lp in lps if "yes" in lp.token.lower())
        no = logsumexp(lp.logprob for lp in lps if "no" in lp.token.lower())
        # explicit check (not assert) so it still fires under `python -O`,
        # instead of caching a NaN probability
        if not max(yes, no) > float("-inf"):
            raise RuntimeError(
                f"no YES/NO token among the top logprobs for item {item!r}"
            )
        p = _p_yes(yes, no)
        results[item] = p
        # OR REPLACE: a 64-bit prefix collision with a different item
        # (suffix mismatch on lookup) would otherwise violate the PRIMARY KEY
        db.execute(
            "INSERT OR REPLACE INTO results (prefix, suffix, p) VALUES (?, ?, ?)",
            (prefix, suffix, p),
        )

    with tqdm(total=len(missing), desc="classifying") as pbar:
        pending = set()
        for item, prefix, suffix in missing:
            pending.add(asyncio.create_task(classify(item, prefix, suffix)))
            await asyncio.sleep(60 / MAX_REQUESTS_PER_MINUTE)
            # reap whatever has finished so far without blocking
            done, pending = await asyncio.wait(
                pending,
                timeout=0,
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                await task  # re-raises any exception from classify()
                pbar.update()
        for task in asyncio.as_completed(pending):
            await task
            pbar.update()


async def _main():
    """Classify each stdin line against a category and print the results.

    Usage: ``classify <category> [model]`` (model defaults to
    ``gpt-4o-mini``). Duplicate lines are processed once. Each line is looked
    up in a per-(category, model) SQLite cache keyed by a hash of the line;
    uncached lines are sent to the OpenAI chat API. Prints ``p<TAB>line``
    for every distinct line in first-seen order, where p is the estimated
    probability of YES.
    """
    from contextlib import closing

    category, model = _parse_args()
    # closing() really closes the connection on exit; sqlite3's own context
    # manager only ends transactions and leaves the connection open.
    with closing(_open_cache(category, model)) as db:
        missing: list[tuple[str, int, int]] = []
        results: dict[str, float | None] = {}
        for line in sys.stdin:
            item = line.rstrip("\n")
            if item in results:
                continue
            prefix, suffix = _cache_key(item)
            cached = db.execute(
                "SELECT suffix, p FROM results WHERE prefix = ?", (prefix,)
            ).fetchone()
            if cached is not None and cached[0] == suffix:  # handle prefix collisions
                results[item] = cached[1]
                continue
            results[item] = None  # set item in results to avoid duplicates
            missing.append((item, prefix, suffix))

        if missing:
            await _classify_missing(db, category, model, missing, results)

        for item, p in results.items():
            print(f"{p:.16f}\t{item}")
        db.execute("VACUUM")
def main():
asyncio.run(_main())
if __name__ == "__main__":
main()
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "classify"
version = "0.1.0"
# The script calls sqlite3.connect(..., autocommit=True), which was added in
# Python 3.12; older interpreters fail at runtime with a TypeError.
requires-python = ">=3.12"
dependencies = [
    "openai>=1.54.4",
    "tqdm>=4.67.0",
]

[project.scripts]
classify = "classify:main"
@milkey-mouse
Copy link
Author

License: CC0-1.0

@milkey-mouse
Copy link
Author

pipx install git+https://gist.github.com/67da93957d21515d8e3276e110867f14

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment