"""Summary
"""
import logging
from pathlib import Path
import fire
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer
# Setup logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def is_length_appropriate(
text,
tokenizer,
min_length=1000,
max_length=16384,
):
"""
Check if the tokenized text length falls within specified bounds.
Args:
text (str): The text to be tokenized.
tokenizer: Tokenizer object.
min_length (int): Minimum token length.
max_length (int): Maximum token length.
    Returns:
bool: True if text length is appropriate, False otherwise.
"""
tokenized = tokenizer(text)
token_length = len(tokenized["input_ids"])
return min_length < token_length < max_length
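# Illustrative usage sketch (the tokenizer below is the script's default; exact token
# counts depend on the tokenizer chosen):
#   tok = AutoTokenizer.from_pretrained("pszemraj/long-t5-tglobal-base-16384-book-summary")
#   is_length_appropriate("a short snippet", tok, min_length=1000)  # -> False, far below min_length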
def save_dataset(dataset, save_path):
"""
Save the dataset to the specified path.
Args:
dataset: The dataset to save.
save_path (str): Path to save the dataset.
"""
save_path = Path(save_path)
dataset.save_to_disk(str(save_path))
logging.info(f"Dataset saved to {save_path}")
def main(
dataset_name,
config_name="default",
text_column="text",
save_path: str = None,
tokenizer_name="pszemraj/long-t5-tglobal-base-16384-book-summary",
max_samples=500,
):
"""
Main function to process and save the dataset.
Args:
dataset_name (str): Name of the dataset.
        config_name (str, optional): Dataset configuration name to load.
        text_column (str, optional): Name of the column containing the text to filter.
save_path (str, optional): Path to save the processed dataset.
tokenizer_name (str): Name of the tokenizer.
max_samples (int): Maximum number of samples to save.
"""
ds_short_name = dataset_name.split("/")[-1]
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# Load the dataset in streaming mode
streaming_dataset = load_dataset(
dataset_name, config_name, split="train", streaming=True
)
# Initialize variables for filtered data
filtered_texts = []
# Process the dataset
for sample in tqdm(streaming_dataset, desc="Processing Dataset"):
if is_length_appropriate(sample[text_column], tokenizer):
filtered_texts.append(sample[text_column])
# Stop after collecting max_samples
if len(filtered_texts) == max_samples:
logging.warning(f"hit max samples: {max_samples}")
break
# Create a dataset from the filtered texts
filtered_dataset = Dataset.from_dict({"text": filtered_texts})
# Save the dataset
max_samples = len(filtered_texts)
save_path = (
save_path
or Path("exported-datasets") / f"{ds_short_name}_filtered-{max_samples}-samples"
)
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
logging.info(f"Saving dataset to {save_path}...")
save_dataset(filtered_dataset, save_path)
logging.info("done!")
if __name__ == "__main__":
fire.Fire(main)
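# Example CLI invocation via python-fire (script file name and dataset are placeholders):
#   python filter_dataset.py --dataset_name your-org/your-dataset --max_samples 500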
import json
import logging
import re
from datetime import datetime
from pathlib import Path
import fire
import joblib
import pandas as pd
from cleantext import clean
from joblib import Memory
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict, load_dataset
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Encoding map for models
ENCODING_MAP = {
"gpt-3.5-turbo": "cl100k_base",
"gpt-3.5-turbo-1106": "cl100k_base",
"gpt-3.5-turbo-16k": "cl100k_base",
"gpt-4": "cl100k_base",
"gpt-4-0613": "cl100k_base",
"gpt-4-0314": "cl100k_base",
"gpt-4-32k": "cl100k_base",
}
_here = Path(__file__).parent
_checkpoints = _here / "checkpoints"
_checkpoints.mkdir(exist_ok=True)
memory_loc = Path.home() / ".cache" / "langchain" / "openai-summarize"
memory_loc.mkdir(parents=True, exist_ok=True)
memory = Memory(memory_loc, verbose=0)
import threading
from functools import wraps
from tenacity import retry, stop_after_attempt, wait_exponential
def get_wordcount(text):
"""robust word count"""
return len(re.findall(r"\w+", text))
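# Quick sanity check of the regex-based count (the \w+ pattern ignores punctuation):
#   get_wordcount("Hello, world! It's 2023.")  # -> 5 matches: Hello, world, It, s, 2023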
import queue
def timeout(seconds):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
q = queue.Queue()
def newFunc():
try:
q.put(func(*args, **kwargs))
except Exception as e:
q.put(e)
            # Run the call in a daemon thread so a timed-out function cannot block interpreter exit
            thread = threading.Thread(target=newFunc, daemon=True)
            thread.start()
            thread.join(seconds)
            if thread.is_alive():
                logging.warning(
                    f"Timeout exceeded for '{func.__name__}' after {seconds} seconds"
                )
                raise TimeoutError(
                    f"function [{func.__name__}] timeout [{seconds} seconds] exceeded!"
                )
result = q.get_nowait()
if isinstance(result, Exception):
raise result
return result
return wrapper
return decorator
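# Illustrative use of the timeout decorator (the decorated function is hypothetical):
#   @timeout(5)
#   def slow_call():
#       ...  # raises TimeoutError if it runs longer than 5 seconds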
@memory.cache
@retry(stop=stop_after_attempt(4), wait=wait_exponential(multiplier=1, min=4, max=10))
def map_reduce_summary(
text,
model,
temperature: float = 0.0,
chunk_size: int = 14336,
chunk_overlap: int = 32,
):
try:
llm = ChatOpenAI(model_name=model, temperature=temperature, request_timeout=20)
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
encoding_name=ENCODING_MAP.get(model, "gpt2"),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
chain_map_reduce = load_summarize_chain(
llm, chain_type="map_reduce", return_intermediate_steps=True
)
input_text = clean(text, lower=False, lang="en")
docs = text_splitter.create_documents([input_text])
map_reduce_output = chain_map_reduce(
{"input_documents": docs}, return_only_outputs=True
)
return map_reduce_output["output_text"]
except Exception as e:
wc = get_wordcount(text)
logging.error(f"Error summarizing text with {wc} words: {e[:100]}")
return ""
def summarize_text(
df, model, chunk_size, chunk_overlap, temperature, save_checkpoint=True
):
"""
Summarizes the text using the specified model and returns the DataFrame with summaries.
"""
logging.info(f"using model: {model}")
df["summary"] = ""
for index, row in tqdm(df.iterrows(), total=len(df), desc="Summarizing"):
df.loc[index, "summary"] = map_reduce_summary(
row["text"], model, temperature, chunk_size, chunk_overlap
)
_initial_rows = len(df)
df = df[df["summary"].str.len() > 0]
logging.info(
f"Removed {len(df) - _initial_rows} rows with empty summaries. {len(df)} rows remaining."
)
if save_checkpoint:
_chk_path = (
_checkpoints / f"{model}-{datetime.now().strftime('%Y%b%d_%H')}.joblib"
)
joblib.dump(
df,
_chk_path,
compress=3,
)
logging.info(f"Checkpoint saved to:\n\t{_chk_path}")
return df
def split_dataset(df, val_test_size):
"""
Splits the dataset into training, validation, and test sets.
"""
ds = Dataset.from_pandas(df)
if val_test_size > 0:
train_temp_split = ds.train_test_split(
test_size=val_test_size, shuffle=True, seed=42
)
train_dataset = train_temp_split["train"]
temp_dataset = train_temp_split["test"]
val_test_split = temp_dataset.train_test_split(
test_size=0.5, shuffle=True, seed=42
)
return DatasetDict(
{
"train": train_dataset,
"validation": val_test_split["train"],
"test": val_test_split["test"],
}
)
else:
return DatasetDict({"train": ds})
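# Split proportions implied above: with val_test_size=0.1 the data is 90% train, and the
# remaining 10% is halved into 5% validation / 5% test; val_test_size=0 keeps everything in "train".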
def save_or_push_dataset(dataset, save_dir, repo_id, output_config_name, private):
"""
Saves the dataset locally or pushes it to Huggingface hub.
"""
if repo_id:
logging.info(f"Pushing dataset to Huggingface hub ({repo_id})...")
if output_config_name:
dataset.push_to_hub(
repo_id=repo_id, config_name=output_config_name, private=private
)
else:
dataset.push_to_hub(repo_id=repo_id, private=private)
logging.info(f"Dataset pushed to {repo_id}")
else:
        save_dir = Path(
            save_dir or Path("datasets") / f"{datetime.now().strftime('%Y%m%d%H%M%S')}"
        )
save_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Saving dataset to {save_dir}...")
dataset.save_to_disk(str(save_dir))
logging.info(f"Dataset saved to {save_dir}")
def main(
dataset_name,
config_name: str = "default",
model="gpt-3.5-turbo-1106",
chunk_size=14336,
chunk_overlap=32,
temperature=0.0,
val_test_size=0.1,
repo_id=None,
output_config_name=None,
private=True,
save_dir=None,
max_samples: int = 500,
):
"""
The main function that processes and summarizes the text files.
"""
logging.info(f"Starting processing, dataset: {dataset_name}, config: {config_name}")
    if Path(dataset_name).is_dir():
logging.info("Loading dataset from disk...")
dataset = Dataset.load_from_disk(dataset_name)
df = dataset.to_pandas().convert_dtypes()
else:
logging.info("Loading dataset from Huggingface hub...")
dataset = load_dataset(dataset_name, config_name)
df = dataset["train"].to_pandas().convert_dtypes()
if max_samples and len(df) > max_samples:
df = df.sample(n=max_samples, random_state=80085)
    df.info()
df = summarize_text(df, model, chunk_size, chunk_overlap, temperature)
dataset = split_dataset(df, val_test_size)
save_or_push_dataset(dataset, save_dir, repo_id, output_config_name, private)
logging.info("Processing complete.")
if __name__ == "__main__":
fire.Fire(main)
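# Example CLI invocation (illustrative; file name and dataset are placeholders, and an
# OpenAI API key must be configured in the environment):
#   python summarize_openai.py --dataset_name your-org/your-dataset --model gpt-3.5-turbo-1106 --max_samples 100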
"""\
Refactored script for summarizing texts using a deep learning model.
Check the docs to see how to install it so it uses your hardware:
References:
https://python.langchain.com/docs/integrations/llms/llamacpp
https://llama-cpp-python.readthedocs.io/en/latest/api-reference/
"""
import json
import logging
import random
import time
from datetime import datetime
from pathlib import Path
import fire
from cleantext import clean
from datasets import load_dataset, load_from_disk
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm.auto import tqdm
from transformers import AutoTokenizer
N_PREVIEW_CHARS = 600
CACHE_LOC = Path("/mnt/c/Users/peter/.cache/ggml-models")
assert (
CACHE_LOC.exists() and CACHE_LOC.is_dir()
), f"wrong default cache, doesn't exist or not dir: {CACHE_LOC}"
# Initialize logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Template for the prompt
template = """GPT4 User: {prompt}GPT4 Assistant:"""
prompt = PromptTemplate(template=template, input_variables=["prompt"])
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
class Timer:
"""
Basic timer utility.
"""
def __enter__(self):
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.perf_counter()
self.elapsed_time = self.end_time - self.start_time
logging.info(f"Elapsed time: {self.elapsed_time:.4f} seconds")
# Map function
def custom_map_function(document, llm):
map_template = """GPT4 User: The following is a document to summarize:
{document}
Please create a comprehensive summary of the above document.<|end_of_turn|>GPT4 Assistant:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)
# Process the document to generate a summary
return map_chain.run({"document": document})
# Reduce function
def custom_reduce_function(mapped_summaries, llm):
reduce_template = """GPT4 User: The following are 'batched' summaries from parts of one long document:
{summaries}
Aggregate the above summaries into a holistic, clearly written overview. Add details and explanations as needed.<|end_of_turn|>GPT4 Assistant:"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
# Combine summaries into a final output
return reduce_chain.run({"summaries": "\n".join(mapped_summaries)})
def summarize_text(
text,
llm,
chunk_size: int = 3072,
chunk_overlap: int = 16,
disable_progress=False,
clean_before=False,
):
"""
Summarizes the text using the custom map-reduce approach.
"""
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
tokenizer=AutoTokenizer.from_pretrained("openchat/openchat_3.5"),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
docs = text_splitter.create_documents([text])
docs = (
        [clean(d.page_content, lower=False, keep_two_line_breaks=True) for d in docs]
if clean_before
else docs
)
logging.debug(f"Number of documents: {len(docs)}")
# Map step: Process each document to generate a summary
mapped_docs = [
custom_map_function(doc, llm)
for doc in tqdm(docs, desc="main summary batches", disable=disable_progress)
]
    if len(mapped_docs) == 1:
        return clean(mapped_docs[0], lower=False)
    # Reduce step: combine the per-chunk summaries into a final overview and return it
    combined_summary = custom_reduce_function(mapped_docs, llm)
    return clean(combined_summary, lower=False)
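# Illustrative call (assumes `llm` is the LlamaCpp instance built in main() and `text`
# is the document to summarize):
#   summary = summarize_text(text, llm, chunk_size=3072)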
def main(
model_id: str = "openchat_3.5-16k.Q4_K_M.gguf",
dataset_name: str = "BEE-spoke-data/govt-pptx",
dataset_config: str = "default",
out_dir: str = None,
chunk_size: int = 3100,
temperature=0.05,
n_ctx=4096,
max_tokens=512,
top_p=1,
numa=True,
n_batch: int = 256,
n_gpu_layers: int = 0,
verbose=False,
**kwargs,
):
"""
Main function for summarizing text using a specified model.
"""
model_path = CACHE_LOC / model_id
assert model_path.exists(), f"{model_path} does not exist"
if Path(dataset_name).exists():
logging.info(f"Loading dataset from disk: {dataset_name}")
dataset = load_from_disk(dataset_name)
text = random.choice(dataset["text"])
else:
dataset = load_dataset(dataset_name, dataset_config)
text = random.choice(dataset["train"]["text"])
logging.info(f"Loading model from {model_path.name}")
llm = LlamaCpp(
model_path=str(model_path.resolve()),
temperature=temperature,
n_ctx=n_ctx,
max_tokens=max_tokens,
top_p=top_p,
numa=numa,
n_batch=n_batch,
n_gpu_layers=n_gpu_layers,
verbose=verbose,
**kwargs,
)
logging.info(
f"Input text (first {N_PREVIEW_CHARS}):\n\n{text[:N_PREVIEW_CHARS]} ...\n\n"
)
with Timer() as timer: # noqa
summary = summarize_text(text, llm, chunk_size=chunk_size)
logging.info(f"Summary:\n\t{summary}")
report = {
"metadata": {
"model": model_id,
"dataset": dataset_name,
"dataset_config": dataset_config,
"chunk_size": chunk_size,
"temperature": temperature,
"n_ctx": n_ctx,
"max_tokens": max_tokens,
"top_p": top_p,
"numa": numa,
"n_batch": n_batch,
"n_gpu_layers": n_gpu_layers,
"verbose": verbose,
},
"summary": summary,
"text": text,
}
out_dir = Path(out_dir) if out_dir is not None else Path("langchain_summaries")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{datetime.now().strftime('%Y%b%d%H%M%S')}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=4)
    logging.info(f"Report saved to {out_path}")
if __name__ == "__main__":
fire.Fire(main)
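# Example CLI invocation (illustrative; the .gguf file must exist under CACHE_LOC and the
# script file name is a placeholder):
#   python summarize_llamacpp.py --model_id openchat_3.5-16k.Q4_K_M.gguf --n_gpu_layers 20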