Last active
August 19, 2023 09:51
-
-
Save johnwheeler/ba661c122543d06a5ca6326ffc8c734a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Git prepare-commit-msg hook that asks the OpenAI API to draft a commit
message from the current diff.

Setup:
    pip install tiktoken requests
    export OPENAI_API_KEY=<redact>
    chmod +x prepare-commit-msg
    mv prepare-commit-msg .git/hooks/
""" | |
import math | |
import os | |
import subprocess | |
import sys | |
import requests | |
import tiktoken | |
# API key for the OpenAI endpoint, read from the environment once at import
# time; None when the variable is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Model name used both in the chat-completion request and to pick the
# tiktoken encoding for token counting.
GPT_MODEL = "gpt-3.5-turbo-16k"

# Token budget shared by the prompt and the requested reply.
# NOTE(review): presumably the model's context window -- confirm against the
# current model documentation.
MAX_TOKENS = 16384
def run_git_diff():
    """Return the diff of the changes staged for the commit being created.

    Uses ``git diff --cached`` so that only staged changes are included and so
    that the very first commit of a repository (where ``HEAD`` does not exist
    yet) still works.

    Returns:
        The diff text (possibly an empty string), or None when git could not
        be started or exited with a non-zero status. The reason is printed.
    """
    command = ["git", "diff", "--cached"]
    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # git reports failures on stderr; fall back to stdout if it is empty.
        print(f"Error occurred: {e.stderr or e.output}")
        return None
    except OSError as e:
        # Raised when the git executable is missing or cannot be executed.
        print(f"Error occurred: {e}")
        return None
    return result.stdout
def num_tokens_from_string(string: str, encoding_name=GPT_MODEL) -> int:
    """Count the tokens that *string* encodes to.

    Args:
        string: Text to tokenize.
        encoding_name: Despite the name, this value is handed to
            ``tiktoken.encoding_for_model``, so it is a model name.
            Defaults to ``GPT_MODEL``.

    Returns:
        The length of the token sequence produced by the model's encoding.
    """
    tokenizer = tiktoken.encoding_for_model(encoding_name)
    return len(tokenizer.encode(string))
def summarize_changes(text):
    """Ask the OpenAI chat-completions API to summarize *text* as a commit message.

    Args:
        text: The diff (or other text) to summarize.

    Returns:
        The generated message, or None when the API key is missing, the prompt
        leaves too little room for a reply within ``MAX_TOKENS``, the request
        fails, or the response does not have the expected shape. A short
        reason is printed in every failure case.
    """
    system_prompt = "Summarize into a commit message followed by bullets."
    # Smallest reply budget worth sending a request for.
    min_reply_tokens = 256
    # Approximate allowance for per-message framing tokens added by the chat
    # format on top of the raw message contents.
    framing_tokens = 16

    if not OPENAI_API_KEY:
        print("Error occurred: OPENAI_API_KEY is not set")
        return None

    # Both messages count against the budget, not just the diff.
    prompt_tokens = (
        num_tokens_from_string(text)
        + num_tokens_from_string(system_prompt)
        + framing_tokens
    )
    # Round up to the next hundred as a safety margin for counting drift.
    prompt_tokens_rounded = int(math.ceil(prompt_tokens / 100.0)) * 100
    num_tokens_left = MAX_TOKENS - prompt_tokens_rounded
    if num_tokens_left < min_reply_tokens:
        print("too many tokens")
        return None

    endpoint = "https://api.openai.com/v1/chat/completions"
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {OPENAI_API_KEY}"}
    payload = {
        "model": GPT_MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        "max_tokens": num_tokens_left,
    }

    try:
        # (connect, read) timeouts so a stalled connection cannot hang the commit.
        response = requests.post(endpoint, headers=headers, json=payload, timeout=(10, 120))
    except requests.RequestException as e:
        print(f"Error occurred: {e}")
        return None

    try:
        data = response.json()
    except ValueError:
        # Body was not JSON (e.g. an HTML error page from a proxy).
        data = None

    if response.status_code != 200:
        message = "Unknown error"
        if isinstance(data, dict):
            error = data.get("error")
            if isinstance(error, dict):
                message = error.get("message", message)
        print(f"Error occurred: {message}")
        return None

    try:
        return data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        print("Error occurred: unexpected response format")
        return None
def main():
    """Entry point of the prepare-commit-msg hook.

    Reads the commit-message file path from ``sys.argv[1]`` and the optional
    message source from ``sys.argv[2]``. When no message exists yet, the
    diff is summarized and the summary is written to the message file.
    """
    if len(sys.argv) < 2:
        print("usage: prepare-commit-msg <commit-msg-file> [<source> [<sha>]]")
        return

    commit_msg_filepath = sys.argv[1]
    source = sys.argv[2] if len(sys.argv) > 2 else ""
    # Git passes a source when a message already exists (-m/-F, merge, squash,
    # or -c/-C/--amend); leave that message untouched instead of overwriting it.
    if source in ("message", "merge", "squash", "commit"):
        return

    diff_output = run_git_diff()
    if not diff_output:
        print("No output from git diff or an error occurred.")
        return

    print("Working...")
    summary = summarize_changes(diff_output)
    if summary:
        with open(commit_msg_filepath, "w", encoding="utf-8") as file:
            file.write(summary)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment