Find duplicate images in a folder: compute an 8x8 average-hash (aHash) fingerprint for every image, score all pairs by Hamming distance, and group near-duplicates with a connected-components pass.
import os
import sys
import glob
import json
import shutil
import pickle
import argparse
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import tqdm
from PIL import Image
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

# Images are downscaled to 8x8 grayscale before hashing, so each hash is a
# 64-bit vector (FEATURE_LEN).
WORKING_SIZE_W = 8
WORKING_SIZE_H = 8
FEATURE_LEN = WORKING_SIZE_W * WORKING_SIZE_H

def img_hash(img_path: str):
    """Average hash (aHash): downscale, grayscale, then threshold each pixel
    against the image mean, giving a binary fingerprint."""
    # LANCZOS replaces the deprecated Image.ANTIALIAS (removed in Pillow 10).
    image = Image.open(img_path).resize((WORKING_SIZE_W, WORKING_SIZE_H), Image.LANCZOS).convert('L')
    img_np = np.array(image.getdata())
    avg = img_np.mean()
    rst_bytes = (img_np > avg).astype(np.uint8)
    return rst_bytes
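
# Illustration (hypothetical file name): img_hash("photo.jpg") returns a
# length-64 uint8 vector of 0s and 1s. Near-duplicate images (resized,
# re-encoded, or lightly edited copies) produce hashes that differ in only
# a few of those 64 bits.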

def _extract_img_hash_worker(img_path: str):
    # Thread-pool worker: return the path alongside its hash.
    return img_path, img_hash(img_path)

def build_hash_db(img_file_list: list, out_db_file, show_progress=False):
    """Hash every image in img_file_list, returning a list of (path, hash)
    tuples; optionally pickle the result to out_db_file."""
    executor = ThreadPoolExecutor(max_workers=8)
    futures = []
    for path in img_file_list:
        futures.append(executor.submit(_extract_img_hash_worker, path))
    rst = []
    if show_progress:
        futures = tqdm.tqdm(futures)
    for future in futures:
        img_path, img_feature = future.result()
        rst.append((img_path, img_feature))
    if out_db_file is not None:
        with open(out_db_file, 'wb') as f:
            pickle.dump(rst, f)
    executor.shutdown()
    return rst

def calc_similarity_for_one_query(query_feature, feature_list):
    # Hamming distance between the query hash and every hash in the db,
    # converted to a 0-100 similarity score.
    dist = np.count_nonzero(query_feature ^ feature_list, axis=1).astype(np.float32)
    scores = (FEATURE_LEN - dist) * 100 / FEATURE_LEN
    return scores
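
# Worked example: with an 8x8 hash, FEATURE_LEN == 64. A pair of hashes that
# differ in 3 of 64 bits has dist == 3 and scores == (64 - 3) * 100 / 64 ≈ 95.3,
# so the pair clears a threshold of 95 but not the default of 98.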

def pairwise_similarity_in_db(feature_db, show_progress=False):
    """Compute the full N x N similarity matrix (O(N^2) time and memory)."""
    executor = ThreadPoolExecutor(max_workers=6)
    futures = []
    db_size = len(feature_db)
    path_list, feature_list = list(zip(*feature_db))
    feature_list = np.array(feature_list)
    similar_matrix = []
    rows = range(db_size)
    if show_progress:
        rows = tqdm.tqdm(rows)
    for i in rows:
        futures.append(executor.submit(calc_similarity_for_one_query, feature_list[i], feature_list))
    if show_progress:
        futures = tqdm.tqdm(futures)
    for future in futures:
        similar_matrix.append(future.result())
    executor.shutdown()
    del futures, executor
    similar_matrix = np.asarray(similar_matrix)
    # Zero the diagonal so an image is never grouped via its own
    # perfect self-similarity.
    np.fill_diagonal(similar_matrix, 0)
    return similar_matrix

def generate_connected_graph(feature_db, threshold=90, show_progress=False):
    # Threshold the similarity matrix into a boolean adjacency matrix; any
    # pair scoring above the threshold is treated as an edge.
    similar_matrix = pairwise_similarity_in_db(feature_db, show_progress)
    connected_matrix = similar_matrix > threshold
    connected_graph = csr_matrix(connected_matrix)
    return connected_graph

def parse_args():
    parser = argparse.ArgumentParser("Duplicate Image Finder")
    parser.add_argument('--source_dir', type=str, default='', help='source image directory')
    parser.add_argument('--out_json', type=str, default='images_groups.json', help='similar image groups as JSON')
    parser.add_argument('--out_dir', type=str, default='', help='output directory; copies images under grouped names')
    parser.add_argument('-c', '--threshold', type=int, default=98, help='similarity threshold, integer 0-100')
    return parser.parse_args()

def main():
    args = parse_args()
    source_dir = args.source_dir
    out_dir = args.out_dir
    if len(out_dir) > 0 and not os.path.exists(out_dir):
        os.mkdir(out_dir)
    working_img_list = glob.glob(os.path.join(os.path.abspath(source_dir), "*.jpg"))
    print("working image cnt", len(working_img_list))
    print("building hash db...")
    sys.stdout.flush()
    hash_db = build_hash_db(working_img_list, out_db_file=None, show_progress=True)
    print("generate connected graph...")
    sys.stdout.flush()
    # Use the threshold from the CLI rather than a hard-coded value.
    connected_graph = generate_connected_graph(hash_db, threshold=args.threshold, show_progress=True)
    print("find connected groups...")
    # Grouping is transitive: if A~B and B~C, all three land in one group
    # even when A and C themselves score below the threshold.
    n_comp, connected_labels = connected_components(connected_graph)
    similar_groups = {}
    for idx, lb in enumerate(connected_labels):
        if lb not in similar_groups:
            similar_groups[lb] = [hash_db[idx][0]]
        else:
            similar_groups[lb].append(hash_db[idx][0])
    similar_groups = sorted(similar_groups.values(), key=len)
    print("groups cnt", len(similar_groups))
    print("ready to output...")
    if args.out_json is not None and len(args.out_json) > 0:
        with open(args.out_json, 'w') as f:
            json.dump(similar_groups, f, ensure_ascii=False)
        print("similar group json saved to", args.out_json)
        sys.stdout.flush()
    if out_dir is not None and len(out_dir) > 0:
        group_id = 0
        for group in tqdm.tqdm(similar_groups):
            group_id += 1
            for path in group:
                # Prefix each copy with a zero-padded group id so members of
                # a group sort next to each other in the output directory.
                new_name = f"{group_id:04d}_{os.path.basename(path)}"
                shutil.copy(path, os.path.join(out_dir, new_name))
        print("finished, output to", out_dir)


if __name__ == '__main__':
    main()
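
A typical invocation might look like this (the filename find_duplicate_images.py is an assumption; use whatever name you saved the gist under):

python find_duplicate_images.py --source_dir ./photos --out_json groups.json --out_dir ./grouped -c 95

Note that the scan only picks up *.jpg files, the JSON output lists each group as absolute paths, and copies in --out_dir carry a four-digit group prefix (e.g. 0001_photo.jpg) so duplicates sort together.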