@abodacs
Forked from jxmorris12/hf_dedup.py
Created December 12, 2024 11:16
import multiprocessing
import os

# Shared mapping of hashes seen so far; the Manager proxy lets every
# filter worker read and update the same dict.
manager = multiprocessing.Manager()
all_hashes_set = manager.dict()

def deduplicate(examples, all_hashes_set):
    # print(len(all_hashes_set))  # optional: number of unique rows seen so far
    input_ids = examples["input_ids"]
    hashes = [hash(tuple(ids)) for ids in input_ids]
    should_keep = []
    for val in hashes:
        if val in all_hashes_set:
            # Already seen: mark this example for removal.
            should_keep.append(False)
        else:
            should_keep.append(True)
            all_hashes_set[val] = 1
    # Dataset.filter keeps the rows whose mask entry is True.
    return should_keep

original_len = len(dataset)
dataset = dataset.filter(
    deduplicate,
    batched=True,
    num_proc=os.cpu_count(),
    fn_kwargs={"all_hashes_set": all_hashes_set},
)
print(f"Removed {original_len - len(dataset)} duplicates")
@abodacs (Author) commented Dec 12, 2024

Awesome! I wonder if we can make this faster with Arrow compute functions (when the dataset format is set to "arrow").

Cc @kszucs_ for hashing functions in pyarrow.compute.

https://x.com/jxmnop/status/1866939467094954225
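Short of an Arrow-native hash kernel, one stopgap in the same spirit is to hash the raw bytes of each sequence instead of building a tuple of Python ints per row. A minimal sketch, not a benchmark: fast_row_hash is a hypothetical helper, and it assumes the token ids fit in int64.

import hashlib
import numpy as np

def fast_row_hash(ids):
    # One C-level hash call over the little-endian int64 buffer;
    # deterministic regardless of PYTHONHASHSEED.
    arr = np.asarray(ids, dtype=np.int64)
    return hashlib.blake2b(arr.tobytes(), digest_size=8).digest()

It drops in for hash(tuple(ids)) inside deduplicate, e.g. hashes = [fast_row_hash(ids) for ids in input_ids]. The conversion is cheapest when the batch is already numpy- or arrow-backed, so np.asarray is close to zero-copy.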
