import json
import multiprocessing as mp
import re
from collections import defaultdict
from functools import partial
from typing import Dict, List, Optional, Set, Tuple, Type

from datasets import Dataset
from tqdm import tqdm

from datasketch import MinHash, MinHashLSH
from dpu_utils.utils.iterators import ThreadedIterator

NON_ALPHA = re.compile("[^A-Za-z_0-9]")
# parameters used in DuplicationIndex
MIN_NUM_TOKENS = 10
NUM_PERM = 256

def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
    """Compute the MinHash of a code snippet."""
    if len(tokens) < MIN_NUM_TOKENS:
        return None
    min_hash = MinHash(num_perm=NUM_PERM)
    for token in set(tokens):
        min_hash.update(token.encode())
    return min_hash

def get_tokens(code: str) -> Set[str]:
    """Tokenize a code snippet."""
    return set(t for t in NON_ALPHA.split(code) if len(t.strip()) > 0)
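
# Illustrative example of the tokenization above: splitting on NON_ALPHA keeps identifiers,
# keywords and numbers while dropping punctuation and whitespace:
#     get_tokens("def foo(x): return x + 1") -> {"def", "foo", "x", "return", "1"}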

class DuplicationIndex:
    def __init__(
        self,
        *,
        duplication_jaccard_threshold: float = 0.95,
    ):
        self._duplication_jaccard_threshold = duplication_jaccard_threshold
        self._num_perm = NUM_PERM
        self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)
        self._duplicate_clusters = defaultdict(set)

    def add(self, code_key: Tuple, min_hash: MinHash) -> None:
        """Add a key to _index (MinHashLSH).

        The min_hash is used to query the closest matches based on the jaccard_threshold.
        The new key is either added to an existing cluster of one close match,
        or a new cluster is created. The clusters built this way depend on the order of add calls.

        Args:
            code_key (Tuple of (index, repo_name, path)):
                Theoretically any hashable key. Here we use a tuple so the information can be retrieved later.
            min_hash: MinHash of the code_key.
        """
        close_duplicates = self._index.query(min_hash)
        if code_key in self._index.keys:
            print(f"Duplicate key {code_key}")
            return

        self._index.insert(code_key, min_hash)
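        # Attach the new key to the first close match that already anchors a cluster;
        # if none of the close matches is a cluster base yet, start a new cluster keyed
        # on the first close match (the for/else below).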
        if len(close_duplicates) > 0:
            for base_duplicate in close_duplicates:
                if base_duplicate in self._duplicate_clusters:
                    self._duplicate_clusters[base_duplicate].add(code_key)
                    break
            else:
                self._duplicate_clusters[close_duplicates[0]].add(code_key)

    def get_duplicate_clusters(self) -> List[List[Dict]]:
        """Export the duplicate clusters.

        For each cluster, the first element is the base element of the cluster.
        The base element has an estimated Jaccard similarity higher than the threshold with all the other elements.

        Returns:
            duplicate_clusters (List[List[Dict]]):
                List of duplicate clusters.
        """
        duplicate_clusters = []
        for base, duplicates in self._duplicate_clusters.items():
            cluster = [base] + list(duplicates)
            # reformat the cluster to be a list of dicts
            cluster = [{"base_index": el} for el in cluster]
            duplicate_clusters.append(cluster)
        return duplicate_clusters

    def save(self, filepath) -> None:
        duplicate_clusters = self.get_duplicate_clusters()
        with open(filepath, "w") as f:
            json.dump(duplicate_clusters, f)

def _compute_min_hash(element):
    index, data = element
    min_hash = get_min_hash([t for t in NON_ALPHA.split(data["text"]) if len(t.strip()) > 0])
    if min_hash is not None:
        return index, min_hash
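
# minhash_iter prefetches dataset rows on a background thread (ThreadedIterator) and fans the
# MinHash computation out to a process pool; results are yielded as they complete, so the
# output order is not guaranteed to match the input order.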
def minhash_iter(dataset_iterator: Type[Dataset]):
    with mp.Pool() as pool:
        for data in pool.imap_unordered(
            _compute_min_hash,
            ThreadedIterator(dataset_iterator, max_queue_size=10000),
            chunksize=100,
        ):
            if data is not None:
                yield data

def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
    """Find duplicate clusters in the dataset in two steps:

    1. Compute the MinHash of each code snippet. MinHash is a tool for fast Jaccard similarity estimation.
       This step is run in an asynchronous multiprocessing pool, see minhash_iter.
    2. Find duplicate clusters. The computed MinHashes are added sequentially to the DuplicationIndex.
       This step cannot be parallelized, so the asynchronous prefetching in the previous step helps to speed up the process.
    """
    di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)

    for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
        di.add(filename, min_hash)

    # Returns a List[Cluster] where Cluster is a List[Dict], each dict keyed by base_index.
    return di.get_duplicate_clusters()

def jaccard_similarity(code1: str, code2: str) -> float:
    """Compute the Jaccard similarity of two code snippets."""
    tokens1 = get_tokens(code1)
    tokens2 = get_tokens(code2)
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)

_shared_dataset = None


def _find_cluster_extremes_shared(cluster, jaccard_threshold):
    """Find a reduced cluster such that each code in the original cluster is similar to at least one code in the reduced cluster.

    Two codes are similar if their Jaccard similarity is above the threshold.

    Args:
        cluster (List[dict]):
            cluster is a list of dicts, each containing the following keys:
                - base_index
                - repo_name
                - path
            This is a typical output of DuplicationIndex.get_duplicate_clusters().
        jaccard_threshold (float):
            threshold for Jaccard similarity.
            Two codes are similar if their Jaccard similarity is above the threshold.

    Returns:
        extremes (List[dict]):
            A reduced representation of the cluster. The field "copies" is added to each dict.
            The "copies" field indicates the number of similar codes in the cluster for an extreme.
    """
    extremes = []
    for element1 in cluster:
        code1 = _shared_dataset[element1["base_index"]]["text"]
        for element2 in extremes:
            code2 = _shared_dataset[element2["base_index"]]["text"]
            if jaccard_similarity(code1, code2) >= jaccard_threshold:
                element2["copies"] += 1
                break
        else:
            element1["copies"] = 1
            extremes.append(element1)
    return extremes

def find_extremes(cluster_list, dataset, jaccard_threshold):
    """Call the _find_cluster_extremes_shared function in a parallel fashion.

    Args:
        cluster_list (List[List[Dict]]):
            each cluster is a list of dicts with the key base_index,
            referring to the index of the base code in the dataset.
        dataset (Type[Dataset]):
            dataset is used to access the content of the code snippets,
            using the base_index from the cluster_list.
            dataset is shared between all the processes using a global variable (any other way to share the dataset?),
            otherwise the multiprocessing is not sped up.
        jaccard_threshold (float):
            the threshold for the Jaccard similarity. The default value is 0.85.

    Returns:
        extremes_list (List[Dict]):
            Each cluster is reduced to its extremes.
            See _find_cluster_extremes_shared for the definition of extremes.
    """
    global _shared_dataset
    _shared_dataset = dataset
    extremes_list = []
    f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
    with mp.Pool() as pool:
        for extremes in tqdm(
            pool.imap_unordered(
                f,
                cluster_list,
            ),
            total=len(cluster_list),
        ):
            extremes_list.append(extremes)
    return extremes_list

def deduplicate_dataset(
    dataset: Type[Dataset], jaccard_threshold: float = 0.85
) -> Tuple[Type[Dataset], List[List[Dict]]]:
    """Deduplicate the dataset using MinHash and Jaccard similarity.

    This function first generates duplicate clusters, then each cluster
    is reduced to the extremes that are similar to the other elements in the cluster.
    Codes are called similar if their Jaccard similarity is greater than jaccard_threshold (0.85 by default).

    Args:
        dataset (Type[Dataset]):
            The dataset to deduplicate.
        jaccard_threshold (float, default=0.85):
            Jaccard threshold to determine if two codes are similar.

    Returns:
        ds_dedup (Type[Dataset]):
            The deduplicated dataset.
        duplicate_clusters (List[List[Dict]]):
            The list of duplicate clusters.
            Each cluster is a list of dicts with the following keys:
                - base_index : int
                    The index of the code in the original dataset.
                - repo_name : str
                - path : str
                - copies : int
                    The number of copies of the code in the cluster. (find_cluster_extremes)
                - is_extreme : bool
                    Whether the code is an extreme in the cluster.
            All the codes in the cluster are removed from the dataset except the extremes.

    Example:
        >>> from datasets import load_dataset
        >>> from minhash_deduplication import deduplicate_dataset
        >>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
        >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
    """
    duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
    duplicate_indices = set(x["base_index"] for cluster in duplicate_clusters for x in cluster)
    extreme_dict = {}
    extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
    for extremes in extremes_clusters:
        for element in extremes:
            extreme_dict[element["base_index"]] = element

    remove_indices = duplicate_indices - set(extreme_dict.keys())
    ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)

    # update duplicate_clusters
    for cluster in duplicate_clusters:
        for element in cluster:
            element["is_extreme"] = element["base_index"] in extreme_dict
            if element["is_extreme"]:
                element["copies"] = extreme_dict[element["base_index"]]["copies"]

    print(f"Original dataset size: {len(dataset)}")
    print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
    print(f"Files in duplicate cluster: {len(duplicate_indices)}")
    print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
    print(f"Filtered dataset size: {len(ds_filter)}")

    return ds_filter, duplicate_clusters
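

# Minimal usage sketch: this assumes the example dataset named in the docstring above is
# accessible and that each record exposes a "text" column, as expected by _compute_min_hash.
# The split slice is illustrative only; a full run over split="train" is far more expensive.
if __name__ == "__main__":
    from datasets import load_dataset

    ds = load_dataset("lvwerra/codeparrot-clean", split="train[:1000]")
    ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
    print(f"Kept {len(ds_dedup)} of {len(ds)} files after near-deduplication")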