Created
March 24, 2023 22:36
-
-
Save guyromm/959656e43bdb5d29daf1b561b26364aa to your computer and use it in GitHub Desktop.
GPT-4-powered long-document summarizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
invocation example: | |
F=2303.12712.pdf ; curl 'https://arxiv.org/pdf/'$F -o $F && pdf2txt $F | ./summarize.py $F-gpt4.json | |
""" | |
import sys,json | |
from transformers import GPT2TokenizerFast | |
import asyncio | |
from chatgpt_wrapper.openai.api import AsyncOpenAIAPI | |
from chatgpt_wrapper.openai.api_shell import ApiShell | |
from chatgpt_wrapper.config import Config | |
# Per-request token budget for prompt-prefix + text slice.
MAX_TOKENS=4096
# GPT-2 tokenizer, used here only to count tokens — presumably as a cheap
# proxy for the target model's tokenizer; TODO confirm counts match closely enough.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# All input lines read from stdin (filled by make_chunks).
lines=[]
# One summary string per processed chunk, in order (filled by main).
summaries=[]
def chunker(lines, fr=0):
    """Find the largest slice lines[fr:i] that, together with the prompt
    prefix, stays under MAX_TOKENS.

    The prefix either asks for a plain summary (no summaries yet) or embeds
    all previous summaries as context for the next chunk.

    Args:
        lines: full list of input lines.
        fr: index of the first unsummarized line.

    Returns:
        (prefix, tosum, fr, i) where ``tosum`` is ``prefix + lines[fr:i]``
        and ``i`` is the exclusive end index — the caller should start the
        next chunk at ``fr = i``.
    """
    tosum = []
    i = 1
    while True:
        if summaries:
            prefix = ['Summaries of previous content, separated by "---", followed by a line with "===", after which there is unsummarized text that you should summarize. Be terse and avoid terms such as "the text" or "the paper".'] + ["\n---\n".join(summaries)] + ["==="]
        else:
            prefix = ['Summarize the following text. Be terse and avoid terms such as "the text" or "the paper"']
        candidate = prefix + lines[fr:i]
        tokens = len(tokenizer("\n".join(candidate))['input_ids'])
        if tokens < MAX_TOKENS:
            tosum = candidate
            if i >= len(lines):
                break
            i += 1
        else:
            # Candidate overflowed the budget. Back off one line so the
            # returned chunk stays under MAX_TOKENS (the original returned
            # the oversized candidate). If backing off would leave no new
            # lines, accept the overflow to guarantee forward progress.
            if i - 1 > fr and tosum:
                i -= 1
            else:
                tosum = candidate
            break
    return prefix, tosum, fr, i
def make_chunks():
    """Read all of stdin into the module-level ``lines`` list, then yield
    token-budgeted summarization chunks produced by chunker().

    Each yielded dict contains:
        text:   prompt prefix + text slice, newline-joined (sent to the model)
        prefix: the prompt prefix alone, newline-joined
        fr, i:  the [fr, i) line range covered by this chunk
    """
    for ln in sys.stdin:
        lines.append(ln)
    fr = 0
    while fr < len(lines):
        prefix, chunk, fr, i = chunker(lines, fr=fr)
        yield {'text': "\n".join(chunk), "prefix": "\n".join(prefix), 'fr': fr, 'i': i}
        # The next chunk starts where this one ended. (The original also
        # computed i = fr + i here, but chunker() overwrites i on the next
        # iteration, so that assignment was dead code.)
        fr = i
async def main():
    """Stream a GPT-4 summary for every stdin chunk, accumulating results
    in the module-level ``summaries`` list and persisting progress to the
    JSON file named by sys.argv[1] after each chunk and on exit.
    """
    config = Config()
    config.set('chat.model', 'gpt4')
    gpt = AsyncOpenAIAPI(config)
    shell = ApiShell(config)
    err = None
    input_chunks = []

    def save():
        # Persist progress so an interruption (rate limit, network error,
        # Ctrl-C) does not lose already-completed summaries.
        with open(sys.argv[1], 'w') as fp:
            fp.write(json.dumps({"summaries": summaries,
                                 "input_chunks": input_chunks,
                                 "err": err}))

    try:
        for chunk in make_chunks():
            input_chunks.append(chunk)
            shell._print_markdown(f"# summarizing lines: {chunk['fr']}:{chunk['i']}/{len(lines)}")
            first = True
            gpt.set_model_temperature(0.0)
            summary = []
            # Iterate under a distinct name: the original reused `chunk`
            # here, shadowing the outer loop variable.
            async for token in gpt.ask_stream(chunk['text']):
                if first:
                    print("")
                    first = False
                print(token, end="")
                summary.append(token)
                sys.stdout.flush()
            summaries.append("".join(summary))
            input_chunks[-1]['output'] = "".join(summary)
            print("\n")
            # Work around rate limit if needed.
            save()
            await asyncio.sleep(5)
    except Exception as e:
        # Top-level boundary: record the failure in the output file rather
        # than crashing with nothing saved.
        err = str(e)
    finally:
        save()
# Script entry point: consumes stdin and writes the JSON result to sys.argv[1].
asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment