Skip to content

Instantly share code, notes, and snippets.

@philschmid
Created January 13, 2025 16:02
Show Gist options
  • Save philschmid/73ad56f0b07060cbd3fc4787a19fa91e to your computer and use it in GitHub Desktop.
from time import time
from datasets import load_dataset
from semhash import SemHash

# Pairs of texts whose similarity exceeds this threshold are treated as duplicates.
deduplication_threshold = 0.98

# Load the dataset to deduplicate.
ds = load_dataset("arcee-ai/The-Tome", split="train")

# Extract the first user prompt of each conversation as the text to compare.
texts = ds.map(lambda x: {"text": x["conversations"][0]["value"]}, num_proc=8)["text"]
print(
    f"Deduplicating 'arcee-ai/The-Tome' dataset with threshold {deduplication_threshold} and length {len(texts)}"
)

# Build the SemHash index and run semantic self-deduplication, timing the
# whole operation.
# BUG FIX: the original stopped the clock right after from_records(), before
# self_deduplicate() ran, so the "Time taken to deduplicate" figure covered
# only index construction and excluded the deduplication itself.
start_time = time()
semhash = SemHash.from_records(records=texts)
deduplication_result = semhash.self_deduplicate(threshold=deduplication_threshold)
end_time = time()

print(
    f"Total number of deduplicated prompts with threshold {deduplication_threshold}: {len(deduplication_result.deduplicated)}"
)
# NOTE(review): duplicate_ratio is printed with a '%' suffix but formatted as a
# raw value — confirm against the semhash docs whether it is already a
# percentage or a 0-1 fraction that would need the :.2% format instead.
print(
    f"Removed {len(texts) - len(deduplication_result.deduplicated)} prompts, with a duplicate ratio of {deduplication_result.duplicate_ratio:.2f}% and {deduplication_result.exact_duplicate_ratio:.2f}% exact duplicates"
)
print(f"Time taken to deduplicate: {end_time - start_time:.2f} seconds")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment