Last active
February 27, 2023 04:00
-
-
Save pszemraj/7fdcaca6c80f889e7ea92233d5aa7bee to your computer and use it in GitHub Desktop.
script to test summarization with the Cohere API
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
run_cohere_summarization.py - Summarize text files with Co.Summarize API as a python script | |
""" | |
import argparse | |
import json | |
import logging | |
import os | |
import pprint as pp | |
import random | |
import shutil | |
import sys | |
import time | |
from datetime import datetime | |
from pathlib import Path | |
import cohere | |
from tqdm import tqdm | |
def setup_logging(loglevel, logfile=None):
    """Configure the root logger to emit to stdout or to a logfile.

    Handlers already attached to the root logger are detached first, because
    ``logging.basicConfig`` does nothing when the root logger has handlers.

    Args:
        loglevel (int): minimum loglevel for emitting messages. If None, the
            root logger's current level is left unchanged.
        logfile (str): path to logfile (opened in "w" mode, i.e. truncated).
            If None, log to stdout.
    """
    logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
    datefmt = "%Y-%m-%d %H:%M:%S"

    root = logging.getLogger()
    # Iterate over a snapshot: removing from the live list while looping over
    # it skips every other handler, and any survivor makes basicConfig a no-op.
    for handler in list(root.handlers):
        root.removeHandler(handler)

    if logfile is None:
        # log to console
        logging.basicConfig(
            level=loglevel,
            stream=sys.stdout,
            format=logformat,
            datefmt=datefmt,
        )
    else:
        # log to file
        logging.basicConfig(
            level=loglevel,
            filename=logfile,
            filemode="w",
            format=logformat,
            datefmt=datefmt,
        )
def get_parser():
    """Build the argument parser for the summarization command line script.

    Registers the positional input directory, the output/extension options,
    the Co.Summarize request options, the API key (defaulting to the
    COHERE_API_KEY environment variable), a dry-run switch, and two verbosity
    flags that both write to ``loglevel``.

    :return argparse.ArgumentParser: the configured parser
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Summarize text files with Co.Summarize API",
    )
    add = cli.add_argument

    # --- input / output locations ---
    add("input_dir", help="Input directory containing files to summarize", type=Path)
    add(
        "-o",
        "--output-dir",
        help="Output directory to save summarized files",
        default=None,
        type=Path,
    )
    add(
        "-ext",
        "--extension",
        help="File extension to look for in input directory",
        default=".txt",
        type=str,
    )

    # --- request options forwarded to the summarize call ---
    add(
        "-m",
        "--model",
        help="summarization model name for cohere.ai summarization",
        default="summarize-xlarge",
        type=str,
    )
    add(
        "--length",
        help="Length of summary options: short, medium, long",
        default="medium",
        type=str,
    )
    add(
        "--format",
        help="Format of summary (bullet or paragraph)",
        default="paragraph",
        type=str,
    )
    add(
        "-e",
        "--extractiveness",
        help="Extractiveness of summary (low, medium, high)",
        default="low",
        type=str,
    )
    add("-t", "--temperature", help="Temperature of summary", default=0.5, type=float)
    add(
        "--additional-command",
        help="Additional command for summary",
        default=None,
        type=str,
    )

    # --- credentials and run mode ---
    add(
        "--api-key",
        help="cohere API key",
        default=os.getenv("COHERE_API_KEY"),
        type=str,
    )
    add(
        "--dry-run",
        help="Test script without calling the API or saving any files",
        action="store_true",
    )

    # --- verbosity: both flags store a constant into the same destination ---
    verbosity_flags = (
        ("-v", "--verbose", logging.INFO, "Set loglevel to INFO"),
        ("-vv", "--very-verbose", logging.DEBUG, "Set loglevel to DEBUG"),
    )
    for short_flag, long_flag, level, help_text in verbosity_flags:
        add(
            short_flag,
            long_flag,
            dest="loglevel",
            action="store_const",
            const=level,
            help=help_text,
        )

    return cli
def summarize_file(
    file_path,
    client,
    model: str = "summarize-xlarge",
    length: str = "medium",
    output_format: str = "paragraph",
    extractiveness: str = "low",
    temperature: float = 0.5,
    additional_command: str = None,
    batch_size: int = 48000,
):
    """
    summarize_file - read a text file and summarize it via ``client.summarize``

    Text no longer than ``batch_size`` characters is sent in a single request.
    Longer text is cut into consecutive ``batch_size``-character slices, each
    slice is summarized on its own, and the partial summaries are joined with
    newlines.

    :param Path file_path: path of the file to summarize (read as UTF-8, undecodable bytes dropped)
    :param cohere.Client client: object exposing ``summarize(...)`` whose result has a ``summary`` attribute
    :param str model: name of model to use for summarization, defaults to "summarize-xlarge"
    :param str length: length of summary, defaults to "medium"
    :param str output_format: format of summary (sent as ``format``), defaults to "paragraph"
    :param str extractiveness: extractiveness of summary, defaults to "low"
    :param float temperature: temperature parameter for summary, defaults to 0.5
    :param str additional_command: additional command for summary, defaults to None
    :param int batch_size: maximum characters per request, defaults to 48000
    :return str: summary of text
    """
    text = Path(file_path).read_text(encoding="utf-8", errors="ignore")

    # options shared by every request made for this file
    request_options = dict(
        model=model,
        length=length,
        format=output_format,
        extractiveness=extractiveness,
        temperature=temperature,
        additional_command=additional_command,
    )

    logging.debug(f"Summarizing text with length {len(text)} characters")
    if len(text) > batch_size:
        # too long for one request: summarize fixed-size slices independently
        chunks = [
            text[start : start + batch_size]
            for start in range(0, len(text), batch_size)
        ]
        partial_summaries = []
        for chunk in tqdm(chunks, desc="batch summaries", total=len(chunks)):
            response = client.summarize(text=chunk, **request_options)
            partial_summaries.append(response.summary)
        summary = "\n".join(partial_summaries)
    else:
        response = client.summarize(text=text, **request_options)
        summary = response.summary

    logging.debug(f"Summarized file {file_path}")
    # only the most recent response is logged when several requests were made
    logging.debug(f"API response: {response}")
    return summary
def summarize_files(
    input_dir: Path,
    output_dir: Path = None,
    extension: str = ".txt",
    model: str = "summarize-xlarge",
    length: str = "medium",
    output_format: str = "paragraph",
    extractiveness: str = "low",
    temperature: float = 0.5,
    additional_command: str = None,
    api_key: str = None,
    dry_run: bool = False,
    max_wait: int = 15,
):
    """
    summarize_files - summarize text files with Co.Summarize API from cohere

    Every file in ``input_dir`` whose name ends with ``extension`` is passed to
    ``summarize_file`` and the result is written to
    ``<output_dir>/<stem>_summary<extension>``. A failure on one file is logged
    and does not stop the run. Unless ``dry_run`` is set, the run parameters are
    also saved to ``<output_dir>/params.json``.

    :param Path input_dir: input directory
    :param Path output_dir: output directory, defaults to None (a "summarize-cohere" directory next to input_dir)
    :param str extension: file extension to summarize, defaults to ".txt"
    :param str model: Co.Summarize model, defaults to "summarize-xlarge"
    :param str length: Length of summary options: short, medium, long, defaults to "medium"
    :param str output_format: Format of summary options: paragraph, bullet, defaults to "paragraph"
    :param str extractiveness: Extractiveness of summary (low, medium, high), defaults to "low"
    :param float temperature: Temperature of summary, defaults to 0.5
    :param str additional_command: Additional command for summary, defaults to None
    :param str api_key: cohere API key, defaults to None; the process exits with status 1 if not set
    :param bool dry_run: Test script without calling the API or saving any files, defaults to False
    :param int max_wait: Upper bound (seconds) of the random pause taken before each file; values below 1 disable the pause, defaults to 15
    :return: Path to output directory
    :raises AssertionError: if input_dir is not an existing directory
    """
    # is_dir() is already False for paths that do not exist
    assert input_dir.is_dir(), f"Input directory {input_dir} does not exist"

    if output_dir is None:
        output_dir = input_dir.parent / "summarize-cohere"
    output_dir.mkdir(parents=True, exist_ok=True)

    if not api_key:
        logging.error(
            "No Co.Here API key specified. Set it using the --api-key command line argument or the COHERE_API_KEY environment variable."
        )
        sys.exit(1)

    # A dry run must not talk to the API, so no client is built for it.
    client = None if dry_run else cohere.Client(api_key)

    # sorted() makes the processing order reproducible between runs
    files = sorted(f for f in input_dir.glob(f"*{extension}") if f.is_file())
    logging.info(f"summarizing {len(files)} files in:\n{str(input_dir)}")

    for file_path in tqdm(files, desc="Summarizing files", total=len(files)):
        logging.info(f"Summarizing file {file_path.name}")
        if dry_run:
            # nothing is requested, so there is nothing to pace
            continue
        if max_wait >= 1:
            # random pause between requests; randint would raise for max_wait < 1
            time.sleep(random.randint(1, max_wait))
        try:
            summary = summarize_file(
                file_path,
                client,
                model,
                length,
                output_format,
                extractiveness,
                temperature,
                additional_command,
            )
            summary_path = output_dir / f"{file_path.stem}_summary{extension}"
            with open(summary_path, "w", encoding="utf-8", errors="ignore") as f:
                f.write(summary)
            logging.info(f"Completed file {file_path.name}")
        except Exception as e:
            # best effort: one bad file must not abort the whole batch
            logging.error(f"Error summarizing file {file_path.name}: {e}")

    summary_params = {
        "input_dir": str(input_dir),
        "run_date": datetime.now().strftime("%Y-%b-%d %H:%M"),
        "extension": extension,
        "model": model,
        "length": length,
        "format": output_format,
        "extractiveness": extractiveness,
        "temperature": temperature,
        "additional_command": additional_command,
    }
    if not dry_run:
        _params_path = output_dir / "params.json"
        with open(_params_path, "w", encoding="utf-8") as f:
            json.dump(summary_params, f, indent=4)
        logging.info(f"Saved summary parameters to {_params_path}")

    logging.info("Completed summarization run")
    return output_dir
def main():
    """Main function for summarize-cohere script

    Parses the command line, logs to ``LOGFILE_summarize-cohere.log`` in the
    current working directory, runs ``summarize_files`` and finally copies the
    logfile into the output directory.
    """
    args = get_parser().parse_args()
    _logfile_path = Path.cwd() / "LOGFILE_summarize-cohere.log"
    setup_logging(args.loglevel, logfile=_logfile_path)

    # keep the API key out of the logfile
    _args = {key: value for key, value in vars(args).items() if key != "api_key"}
    logging.info(f"starting new summarization run with args:\n{pp.pformat(_args)}")

    input_dir = Path(args.input_dir).resolve()
    assert input_dir.exists(), f"Input directory {input_dir} does not exist"
    output_dir = (
        Path(args.output_dir)
        if args.output_dir is not None
        else input_dir.parent / "summarize-cohere"
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    logging.info(
        f"Summarizing files in {str(input_dir.resolve())} and saving to:\n\t{str(output_dir.resolve())}"
    )

    output_dir = summarize_files(
        input_dir=input_dir,
        output_dir=output_dir,
        extension=args.extension,
        model=args.model,
        length=args.length,
        output_format=args.format,
        extractiveness=args.extractiveness,
        temperature=args.temperature,
        additional_command=args.additional_command,
        api_key=args.api_key,
        dry_run=args.dry_run,
    )

    try:
        shutil.copy(_logfile_path, output_dir / _logfile_path.name)
    except shutil.SameFileError:
        # output directory is the working directory: the logfile is already
        # where it belongs, so there is nothing to copy
        logging.debug(f"Logfile already located in {output_dir}")
    print(f"Done! Files saved to:\n\t{str(output_dir.resolve())}")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment