Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Last active February 27, 2023 04:00
Show Gist options
  • Save pszemraj/7fdcaca6c80f889e7ea92233d5aa7bee to your computer and use it in GitHub Desktop.
Script to test summarization with the Cohere API.
"""
run_cohere_summarization.py - Summarize text files with Co.Summarize API as a python script
"""
import argparse
import json
import logging
import os
import pprint as pp
import random
import shutil
import sys
import time
from datetime import datetime
from pathlib import Path
import cohere
from tqdm import tqdm
def setup_logging(loglevel, logfile=None):
    """Configure the root logger, replacing any handlers it already has.

    Args:
        loglevel (int or None): minimum loglevel for emitting messages. ``None``
            leaves the root logger's current level unchanged.
        logfile (str or Path): path to logfile (truncated on every call).
            If None, log to stdout.
    """
    logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
    datefmt = "%Y-%m-%d %H:%M:%S"
    root = logging.getLogger()
    # Iterate over a copy: removeHandler() mutates root.handlers in place, so
    # looping over the live list skips every other handler. A handler left
    # behind would also turn the basicConfig() call below into a silent no-op.
    for handler in root.handlers[:]:
        root.removeHandler(handler)
        handler.close()  # release file descriptors held by old FileHandlers
    if logfile is None:
        # log to console
        logging.basicConfig(
            level=loglevel,
            stream=sys.stdout,
            format=logformat,
            datefmt=datefmt,
        )
    else:
        # log to file, overwriting any previous run's log
        logging.basicConfig(
            level=loglevel,
            filename=logfile,
            filemode="w",
            format=logformat,
            datefmt=datefmt,
        )
class _SecretDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """Show argument defaults in ``--help``, except for secret options.

    ``--api-key`` defaults to the ``COHERE_API_KEY`` environment variable; the
    stock ArgumentDefaultsHelpFormatter would print that key in plain text.
    """

    _secret_dests = frozenset({"api_key"})

    def _get_help_string(self, action):
        if action.dest in self._secret_dests:
            return action.help  # never append "(default: <secret>)"
        return super()._get_help_string(action)


def get_parser():
    """Build the command line argument parser for the summarization script.

    Defaults are shown in ``--help`` for every option except ``--api-key``,
    whose default is read from the ``COHERE_API_KEY`` environment variable
    and must not be echoed to the terminal.

    :return argparse.ArgumentParser: the configured parser
    """
    parser = argparse.ArgumentParser(
        description="Summarize text files with Co.Summarize API",
        formatter_class=_SecretDefaultsHelpFormatter,
    )
    parser.add_argument(
        "input_dir", type=Path, help="Input directory containing files to summarize"
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=Path,
        default=None,
        help="Output directory to save summarized files",
    )
    parser.add_argument(
        "-ext",
        "--extension",
        type=str,
        default=".txt",
        help="File extension to look for in input directory",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="summarize-xlarge",
        help="summarization model name for cohere.ai summarization",
    )
    parser.add_argument(
        "--length",
        type=str,
        default="medium",
        help="Length of summary options: short, medium, long",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="paragraph",
        help="Format of summary (bullets or paragraph)",
    )
    parser.add_argument(
        "-e",
        "--extractiveness",
        type=str,
        default="low",
        help="Extractiveness of summary (low, medium, high)",
    )
    parser.add_argument(
        "-t", "--temperature", type=float, default=0.5, help="Temperature of summary"
    )
    parser.add_argument(
        "--additional-command",
        type=str,
        default=None,
        help="Additional command for summary",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=os.getenv("COHERE_API_KEY"),
        help="cohere API key (defaults to the COHERE_API_KEY environment variable)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Test script without calling the API or saving any files",
    )
    # -v / -vv share one dest; with neither flag, loglevel stays None.
    parser.add_argument(
        "-v",
        "--verbose",
        dest="loglevel",
        help="Set loglevel to INFO",
        action="store_const",
        const=logging.INFO,
    )
    parser.add_argument(
        "-vv",
        "--very-verbose",
        dest="loglevel",
        help="Set loglevel to DEBUG",
        action="store_const",
        const=logging.DEBUG,
    )
    return parser
def _split_into_batches(text, max_chars):
    """Split ``text`` into consecutive pieces of at most ``max_chars`` characters.

    A piece is cut at a whitespace boundary whenever its window contains one,
    so words are not split in half; a window with no whitespace at all is cut
    hard at ``max_chars``. Concatenating the pieces reproduces ``text``
    exactly. Text that already fits (including "") is returned as one piece.

    :param str text: text to split
    :param int max_chars: maximum characters per piece, must be positive
    :raises ValueError: if ``max_chars`` is not positive
    :return list[str]: the pieces, in order
    """
    if max_chars <= 0:
        raise ValueError(f"batch_size must be positive, got {max_chars}")
    if len(text) <= max_chars:
        return [text]
    pieces = []
    start = 0
    while start < len(text):
        end = start + max_chars
        if end >= len(text):
            pieces.append(text[start:])
            break
        cut = end
        # Walk back until the boundary no longer falls inside a word.
        while cut > start and not (text[cut - 1].isspace() or text[cut].isspace()):
            cut -= 1
        if cut == start:  # one "word" longer than the window: hard cut
            cut = end
        pieces.append(text[start:cut])
        start = cut
    return pieces


def summarize_file(
    file_path,
    client,
    model: str = "summarize-xlarge",
    length: str = "medium",
    output_format: str = "paragraph",
    extractiveness: str = "low",
    temperature: float = 0.5,
    additional_command: str = None,
    batch_size: int = 48000,
):
    """
    summarize_file - Summarize a text file using the Co.Summarize API

    Text longer than ``batch_size`` characters is split into batches (at word
    boundaries where possible), each batch is summarized separately, and the
    batch summaries are joined with newlines.

    :param Path file_path: Path to file to summarize (read as UTF-8, undecodable bytes dropped)
    :param cohere.Client client: Co.Summarize API client
    :param str model: name of model to use for summarization, defaults to "summarize-xlarge"
    :param str length: length of summary, defaults to "medium"
    :param str output_format: format of summary, defaults to "paragraph"
    :param str extractiveness: extractiveness of summary, defaults to "low"
    :param float temperature: temperature parameter for summary, defaults to 0.5
    :param str additional_command: additional command for summary, defaults to None
    :param int batch_size: max characters per API call, defaults to 48000
    :raises ValueError: if ``batch_size`` is not positive
    :return str: summary of text
    """
    with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read()
    logging.debug(f"Summarizing text with length {len(text)} characters")
    batches = _split_into_batches(text, batch_size)
    # Only show a progress bar when the file actually needed splitting.
    if len(batches) > 1:
        batch_iter = tqdm(batches, desc="batch summaries", total=len(batches))
    else:
        batch_iter = batches
    summaries = []
    for batch in batch_iter:
        response = client.summarize(
            text=batch,
            model=model,
            length=length,
            format=output_format,
            extractiveness=extractiveness,
            temperature=temperature,
            additional_command=additional_command,
        )
        logging.debug(f"API response: {response}")
        summaries.append(response.summary)
    logging.debug(f"Summarized file {file_path}")
    return "\n".join(summaries)
def summarize_files(
    input_dir: Path,
    output_dir: Path = None,
    extension: str = ".txt",
    model: str = "summarize-xlarge",
    length: str = "medium",
    output_format: str = "paragraph",
    extractiveness: str = "low",
    temperature: float = 0.5,
    additional_command: str = None,
    api_key: str = None,
    dry_run: bool = False,
    max_wait: int = 15,
):
    """
    summarize_files - summarize text files with Co.Summarize API from cohere

    Every ``*{extension}`` file directly inside ``input_dir`` is summarized to
    ``{stem}_summary{extension}`` in ``output_dir``, and the run parameters are
    written to ``output_dir/params.json``. A file that fails is logged and
    skipped.

    :param Path input_dir: input directory
    :param Path output_dir: output directory, defaults to None (``input_dir.parent / "summarize-cohere"``)
    :param str extension: file extension to summarize, defaults to ".txt"
    :param str model: Co.Summarize model, defaults to "summarize-xlarge"
    :param str length: Length of summary options: short, medium, long, defaults to "medium"
    :param str output_format: Format of summary options: paragraph, bullets, defaults to "paragraph"
    :param str extractiveness: Extractiveness of summary (low, medium, high), defaults to "low"
    :param float temperature: Temperature of summary, defaults to 0.5
    :param str additional_command: Additional command for summary, defaults to None
    :param str api_key: cohere API key, defaults to None; the process exits with status 1 if it is not set
    :param bool dry_run: list the files without calling the API or writing summaries/params, defaults to False
    :param int max_wait: upper bound (seconds) of the random pause between consecutive
        API calls; values below 1 disable the pause, defaults to 15
    :raises NotADirectoryError: if ``input_dir`` is not an existing directory
    :return: Path to output directory
    """
    input_dir = Path(input_dir)
    if not input_dir.is_dir():
        raise NotADirectoryError(f"Input directory {input_dir} does not exist")
    output_dir = (
        input_dir.parent / "summarize-cohere" if output_dir is None else Path(output_dir)
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    if not api_key:
        logging.error(
            "No Co.Here API key specified. Set it using the --api-key command line argument or the COHERE_API_KEY environment variable."
        )
        sys.exit(1)
    # Only build the client for real runs: a dry run must not talk to the API,
    # and constructing the client may already contact it (SDK-dependent).
    client = None if dry_run else cohere.Client(api_key)
    # Sorted for a deterministic processing order (glob order is arbitrary).
    files = sorted(f for f in input_dir.glob(f"*{extension}") if f.is_file())
    logging.info(f"summarizing {len(files)} files in:\n{str(input_dir)}")
    n_calls = 0
    n_failed = 0
    for file_path in tqdm(files, desc="Summarizing files", total=len(files)):
        logging.info(f"Summarizing file {file_path.name}")
        if dry_run:
            continue
        # Random pause between consecutive API calls (presumably to stay under
        # rate limits); no need to wait before the first call.
        if n_calls and max_wait >= 1:
            time.sleep(random.randint(1, max_wait))
        n_calls += 1
        try:
            summary = summarize_file(
                file_path,
                client,
                model,
                length,
                output_format,
                extractiveness,
                temperature,
                additional_command,
            )
            summary_path = output_dir / f"{file_path.stem}_summary{extension}"
            with open(summary_path, "w", encoding="utf-8", errors="ignore") as f:
                f.write(summary)
            logging.info(f"Completed file {file_path.name}")
        except Exception as e:
            # Best effort: record the failure (with traceback) and move on.
            n_failed += 1
            logging.exception(f"Error summarizing file {file_path.name}: {e}")
    summary_params = {
        "input_dir": str(input_dir),
        "run_date": datetime.now().strftime("%Y-%b-%d %H:%M"),
        "extension": extension,
        "model": model,
        "length": length,
        "format": output_format,
        "extractiveness": extractiveness,
        "temperature": temperature,
        "additional_command": additional_command,
    }
    if not dry_run:
        params_path = output_dir / "params.json"
        with open(params_path, "w", encoding="utf-8") as f:
            json.dump(summary_params, f, indent=4)
        logging.info(f"Saved summary parameters to {params_path}")
    if n_failed:
        logging.warning(f"{n_failed} of {len(files)} files failed to summarize")
    logging.info("Completed summarization run")
    return output_dir
def main():
    """Main function for summarize-cohere script.

    Parses CLI arguments, logs to ``LOGFILE_summarize-cohere.log`` in the
    current working directory, summarizes the input files, and finally copies
    the log file into the output directory.
    """
    parser = get_parser()
    args = parser.parse_args()
    logfile_path = Path.cwd() / "LOGFILE_summarize-cohere.log"
    setup_logging(args.loglevel, logfile=logfile_path)
    # Never write the API key to the log file.
    loggable_args = {k: v for k, v in vars(args).items() if k != "api_key"}
    logging.info(
        f"starting new summarization run with args:\n{pp.pformat(loggable_args)}"
    )
    input_dir = Path(args.input_dir).resolve()
    if not input_dir.is_dir():
        # Report bad CLI input through argparse (usage + exit status 2).
        parser.error(f"Input directory {input_dir} does not exist")
    output_dir = (
        Path(args.output_dir)
        if args.output_dir is not None
        else input_dir.parent / "summarize-cohere"
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    logging.info(
        f"Summarizing files in {str(input_dir)} and saving to:\n\t{str(output_dir.resolve())}"
    )
    output_dir = summarize_files(
        input_dir=input_dir,
        output_dir=output_dir,
        extension=args.extension,
        model=args.model,
        length=args.length,
        output_format=args.format,
        extractiveness=args.extractiveness,
        temperature=args.temperature,
        additional_command=args.additional_command,
        api_key=args.api_key,
        dry_run=args.dry_run,
    )
    try:
        shutil.copy(logfile_path, output_dir / logfile_path.name)
    except shutil.SameFileError:
        # Output dir is the working directory: the log file is already there.
        pass
    print(f"Done! Files saved to:\n\t{str(output_dir.resolve())}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment