Created
September 17, 2024 13:38
-
-
Save danvk/25a87ee22af73cdd32d2d707a42561b7 to your computer and use it in GitHub Desktop.
Python Type Hint Autocomplete repro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Manage multiple batches on OpenAI""" | |
import json | |
import hashlib | |
import os | |
import sys | |
import time | |
from typing import Optional, TypedDict, get_type_hints | |
from dotenv import load_dotenv | |
import openai | |
def sha256sum(filename: str): | |
with open(filename, "rb", buffering=0) as f: | |
return hashlib.file_digest(f, "sha256").hexdigest() | |
# Cache of per-input-file processing state, keyed by the input file's SHA-256.
STATUS_FILE = "/tmp/batch-status.json"

# Pull the Literal type of Batch.status out of the SDK's own annotations so
# the status predicates below stay in sync with the openai package.
# https://stackoverflow.com/q/78991776/388951
type_hints = get_type_hints(openai.types.Batch)
BatchStatus = type_hints["status"]
class FileStatus(TypedDict, total=False):
    """Cached processing state for one input .jsonl file.

    ``total=False`` because entries are built incrementally: a fresh entry
    holds only ``filename`` and ``sha256``, and the remaining keys are
    filled in as the upload / batch-creation / download steps complete.
    """

    filename: str
    sha256: str
    file_id: Optional[str]  # OpenAI file ID once uploaded
    batch_id: Optional[str]  # OpenAI batch ID once created
    batch_status: Optional[BatchStatus]  # last-seen batch status
    batch: dict  # TODO: any way to say "JSON version of openai.types.Batch?"
    output_file_sha256: Optional[str]  # SHA-256 of the downloaded output file
def load_status(sha256: str) -> FileStatus | None:
    """Return the cached FileStatus for a file's SHA-256, or None.

    None is returned when the status file does not exist yet or when it
    has no entry for this hash.
    """
    # EAFP: avoids the exists()/open() race of an explicit existence check.
    try:
        with open(STATUS_FILE) as f:
            statuses = json.load(f)
    except FileNotFoundError:
        return None
    return statuses.get(sha256)
def dump_status(sha256: str, status: FileStatus):
    """Insert or replace the entry for *sha256* in the status cache file.

    Reads the whole JSON cache (if any), updates the one key, and writes
    it back. Not safe against concurrent writers; fine for this
    single-process tool.
    """
    # EAFP: avoids the exists()/open() race of an explicit existence check.
    try:
        with open(STATUS_FILE) as f:
            cache = json.load(f)
    except FileNotFoundError:
        cache = {}
    cache[sha256] = status
    with open(STATUS_FILE, "w") as f:
        json.dump(cache, f)
def is_failed_batch(status: BatchStatus):
    """True if the batch reached (or is on its way to) a terminal failure state."""
    return status in ("cancelled", "cancelling", "expired", "failed")
def is_in_progress_batch(status: BatchStatus):
    """True while the batch is still being validated or processed."""
    return status in ("finalizing", "in_progress", "validating")
def is_done_batch(status: BatchStatus):
    """True once the batch has finished successfully."""
    return status in ("completed",)
if __name__ == "__main__":
    # Usage: manage_batches.py file1.jsonl [file2.jsonl ...]
    load_dotenv()  # pick up OPENAI_API_KEY etc. from a .env file
    batch_files = sys.argv[1:]
    client = openai.OpenAI()
    # For each file:
    # - Check if we've already uploaded it to OpenAI
    # - Upload it if necessary, save file ID to cache
    # - Create the batch (if no successful or pending batch for this file ID)
    # - Monitor the batch status
    # - Retrieve the batch output when done
    for batch_file in batch_files:
        assert batch_file.endswith(".jsonl")
        # The input file's hash is the key for all cached state, so edits to
        # the input start the whole pipeline over.
        sha256 = sha256sum(batch_file)
        print(f"{batch_file}: {sha256}")
        status = load_status(sha256)
        if not status:
            # First time we've seen this file: start a fresh cache entry.
            status: FileStatus = {"filename": batch_file, "sha256": sha256}
            dump_status(sha256, status)
        # Step 1: upload the input file (skipped if cached).
        file_id = status.get("file_id")
        if not file_id:
            print(f"{batch_file}: uploading to OpenAI")
            f = client.files.create(file=open(batch_file, "rb"), purpose="batch")
            file_id = f.id
            status["file_id"] = file_id
            dump_status(sha256, status)
            print(f"{batch_file}: {file_id=}")
        else:
            print(f"{batch_file}: using previously-uploaded {file_id=}")
        # Step 2: create the batch job (skipped if cached).
        batch_id = status.get("batch_id")
        if not batch_id:
            print(f"{batch_file}: creating batch")
            r = client.batches.create(
                input_file_id=file_id,
                completion_window="24h",
                endpoint="/v1/chat/completions",
                metadata={
                    "filename": batch_file,
                    "sha256": sha256,
                },
            )
            batch_id = r.id
            batch_status = r.status
            status["batch_id"] = batch_id
            status["batch_status"] = batch_status
            print(f"{batch_file}: created batch {batch_id=}, {batch_status=}")
            dump_status(sha256, status)
        else:
            print(f"{batch_file}: using previously-created {batch_id=}")
        did_fetch_status = False

        def fetch_batch_status():
            # Pull the latest batch state from the API, persist it to the
            # cache, and print a progress line. Communicates with the polling
            # loop below by mutating module-level batch_status.
            global batch_status, did_fetch_status
            did_fetch_status = True
            r = client.batches.retrieve(batch_id=batch_id)
            batch_status = r.status
            status["batch"] = r.to_dict(mode="json")
            status["batch_status"] = batch_status
            dump_status(sha256, status)
            counts = status["batch"].get("request_counts")
            msg = ""
            if counts:
                completed = counts.get("completed", 0)
                total = counts.get("total")
                if total > 0:
                    msg = f"{completed} / {total} complete "
            print(
                f"{batch_file}: retrieved batch status {batch_status} {msg}{batch_id=}"
            )

        # Prefer the cached status; only hit the API if we have none yet.
        batch_status = status.get("batch_status")
        if not batch_status:
            fetch_batch_status()

        def check_fail_status():
            # Exit the whole script if the batch ended in a failure state.
            if is_failed_batch(batch_status):
                sys.stderr.write(
                    f"{batch_file}: {batch_id} batch failed: {batch_status}\n"
                )
                sys.stderr.write(
                    f"Check https://platform.openai.com/batches/{batch_id}\n"
                )
                sys.exit(1)

        check_fail_status()
        # Step 3: poll until the batch leaves the in-progress states.
        # No sleep before the first fetch when we started from cached status.
        while is_in_progress_batch(batch_status):
            if did_fetch_status:
                time.sleep(5)
            # TODO: reduce repetition
            fetch_batch_status()
            if is_done_batch(batch_status) or is_failed_batch(batch_status):
                break
        check_fail_status()
        # batch must be done!
        # Step 4: download the output file (skipped if it already matches).
        r = status["batch"]
        assert r is not None
        out_file_id = r["output_file_id"]
        out_path = batch_file.replace(".jsonl", ".output.jsonl")
        output_file_sha256 = status.get("output_file_sha256")
        if (
            output_file_sha256
            and os.path.exists(out_path)
            and sha256sum(out_path) == output_file_sha256
        ):
            print(f"{batch_file}: output already exists at {out_path} and matches SHA")
        else:
            counts = r["request_counts"]
            completed = counts["completed"]
            failed = counts["failed"]
            # Batch timestamps are numeric, so subtraction yields elapsed
            # seconds — presumably Unix epoch seconds; TODO confirm vs SDK.
            created_at_s = r["created_at"]
            completed_at_s = r["completed_at"]
            elapsed_s = completed_at_s - created_at_s
            print(
                f"{batch_file}: {completed} completed / {failed} failed in {elapsed_s:.0f}s"
            )
            r = client.files.content(file_id=out_file_id)
            with open(out_path, "wb") as out:
                out.write(r.content)
            # Record the output's hash so reruns can skip the download.
            output_file_sha256 = sha256sum(out_path)
            status["output_file_sha256"] = output_file_sha256
            dump_status(sha256, status)
            print(f"{batch_file}: downloaded output to {out_path}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment