Skip to content

Instantly share code, notes, and snippets.

@jxmorris12
Created December 5, 2024 19:49
Show Gist options
  • Save jxmorris12/7ff8b8b51a906a98826b694ba28ae275 to your computer and use it in GitHub Desktop.
import multiprocessing
# A Manager runs its own server process and hands out proxy objects; a proxy
# can be passed to other processes, which then all operate on the one object
# living in the manager.
manager = multiprocessing.Manager()

# Shared dict used purely as a set: membership of a key is all that is ever
# checked, the stored values are placeholders.
all_hashes_set = manager.dict()
def deduplicate(examples, all_hashes_set):
    """Return a keep-mask for a batch, dropping sequences that were seen before.

    Intended as the predicate for ``dataset.filter(..., batched=True)``.
    ``filter`` keeps the rows for which the predicate returns ``True``, so
    this returns ``True`` for the first occurrence of each ``input_ids``
    sequence. It returns ``False`` for every later repeat, whether that repeat
    is in this batch or was recorded into ``all_hashes_set`` by an earlier call.

    Args:
        examples: Batch mapping with an ``'input_ids'`` column holding one
            token-id sequence per example.
        all_hashes_set: Dict-like object used as a set of already-seen
            sequence hashes (e.g. a ``multiprocessing.Manager().dict()``
            proxy shared by all filter workers). Updated in place.

    Returns:
        A list of bools, one per example: ``True`` to keep, ``False`` to drop.
    """
    print(len(all_hashes_set))
    keep = []
    for ids in examples['input_ids']:
        # For integer token ids, hash() of a tuple is not salted by
        # PYTHONHASHSEED, so every worker process gets the same key.
        key = hash(tuple(ids))
        # NOTE(review): when all_hashes_set is a Manager proxy, the membership
        # test and the insert are two separate round-trips. Two workers can
        # therefore both miss the same key and both keep their copy, so a few
        # cross-shard duplicates may survive. Conversely, a hash collision
        # between distinct sequences would drop a non-duplicate.
        if key in all_hashes_set:
            keep.append(False)  # seen before -> duplicate, drop it
        else:
            keep.append(True)  # first occurrence -> keep and remember it
            all_hashes_set[key] = 1
    return keep
import os  # needed for os.cpu_count() below

# Drop every example whose `input_ids` sequence already appeared earlier.
# The shared Manager dict is passed through `fn_kwargs` so that every
# `num_proc` worker records into, and checks against, the same set of hashes.
original_len = len(dataset)
dataset = dataset.filter(
    deduplicate,
    batched=True,
    num_proc=os.cpu_count(),
    fn_kwargs={"all_hashes_set": all_hashes_set},
)
print(f"Removed {original_len - len(dataset)} duplicates")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment