Created November 1, 2023 20:23
Save amenabe22/517aee4335ce5a455df6440f1846637c to your computer and use it in GitHub Desktop.
pariwise_similarity_filter.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
from PIL import Image | |
from skimage import io | |
from brisque import BRISQUE | |
from skimage.transform import resize | |
from skimage.metrics import structural_similarity as ssim | |
def load_images(image_folder, target_size=(256, 256)):
    """Load every regular file in *image_folder* as an image and resize it.

    Args:
        image_folder: Directory to read images from. Subdirectories are skipped.
        target_size: (rows, cols) each image is resized to; the channel
            dimension, if any, is kept by ``resize``.

    Returns:
        ``(images, image_files)`` as parallel lists: ``images[k]`` was read from
        ``image_files[k]``. Each image is a 3-channel array. ``resize`` returns
        floats, rescaled to [0, 1] for integer inputs per scikit-image's
        ``img_as_float`` convention (assumed by the SSIM ``data_range=1.0``).
    """
    # Sort so the order, and therefore the greedy grouping, is reproducible:
    # os.listdir returns entries in arbitrary order. Skip directories, which
    # io.imread cannot read.
    image_files = sorted(
        name
        for name in os.listdir(image_folder)
        if os.path.isfile(os.path.join(image_folder, name))
    )
    images = []
    for image_file in image_files:
        # Read and resize one file at a time so only one full-resolution image
        # is held in memory at once.
        image = resize(io.imread(os.path.join(image_folder, image_file)), target_size)
        # Normalise to exactly 3 channels so every image has the same shape and
        # channel_axis=2 is valid in the SSIM comparison.
        if image.ndim == 2:  # grayscale -> add a channel axis
            image = image[..., np.newaxis]
        if image.shape[-1] in (2, 4):  # LA / RGBA -> drop the alpha channel
            image = image[..., :-1]
        if image.shape[-1] == 1:  # single channel -> replicate to RGB
            image = np.repeat(image, 3, axis=-1)
        images.append(image)
    return images, image_files
def calculate_similarity(images, win_size=7, channel_axis=2):
    """Return the pairwise SSIM matrix for *images*.

    Args:
        images: Sequence of same-shaped images with values in [0, 1]
            (``data_range=1.0`` is passed to SSIM).
        win_size: Side length of the SSIM sliding window (odd, and no larger
            than the smallest image side).
        channel_axis: Axis of each image that holds the colour channels.

    Returns:
        An ``(n, n)`` float array where entry ``[i, j]`` is
        ``ssim(images[i], images[j])``.
    """
    n_images = len(images)
    similarity_matrix = np.zeros((n_images, n_images))
    for i in range(n_images):
        # The SSIM formula is symmetric in its two inputs, so compute the upper
        # triangle (including the diagonal) and mirror it. This halves the
        # number of SSIM evaluations.
        for j in range(i, n_images):
            # `channel_axis` alone selects per-channel SSIM. The deprecated
            # `multichannel` flag is not passed: it is superseded by
            # channel_axis and could override it.
            score = ssim(
                images[i],
                images[j],
                win_size=win_size,
                data_range=1.0,
                channel_axis=channel_axis,
            )
            similarity_matrix[i, j] = score
            similarity_matrix[j, i] = score
    return similarity_matrix
def group_images_by_similarity(similarity_matrix, image_files, similarity_threshold):
    """Greedily cluster images whose pairwise similarity exceeds a threshold.

    Images are visited in index order. Each image not yet assigned to a group
    starts a new group (as its first entry). That group then absorbs every
    later, not-yet-grouped image ``j`` with
    ``similarity_matrix[i, j] > similarity_threshold``. Every image therefore
    appears in exactly one group.

    Args:
        similarity_matrix: Square array indexed by image position.
        image_files: Filenames, parallel to the matrix rows/columns.
        similarity_threshold: Strict lower bound for joining a group.

    Returns:
        A list of groups. Each group is a list of dicts
        ``{"filename": str, "similarity_score": float}``. The seed's score is
        its own diagonal entry ``similarity_matrix[i, i]``. Each other member's
        score is its similarity to the seed.
    """
    n_images = len(image_files)
    image_groups = []
    grouped = set()  # indices already assigned to some group
    for i in range(n_images):
        if i in grouped:
            continue
        grouped.add(i)
        # The seed's score is its own self-similarity, i.e. row i / column i.
        group = [{"filename": image_files[i],
                  "similarity_score": float(similarity_matrix[i, i])}]
        for j in range(i + 1, n_images):
            # Skip images already claimed by an earlier group so no image is
            # reported (and later labelled) twice.
            if j not in grouped and similarity_matrix[i, j] > similarity_threshold:
                group.append({"filename": image_files[j],
                              "similarity_score": float(similarity_matrix[i, j])})
                grouped.add(j)
        image_groups.append(group)
    return image_groups
# Blind (no-reference) image quality score.
def get_brisque_score(input_path):
    """Return the BRISQUE score of the image stored at *input_path*.

    The image is converted to RGB before scoring. BRISQUE is conventionally
    lower-is-better; confirm the exact range against the installed `brisque`
    package.
    """
    # Use a context manager so the file handle is released even if decoding
    # fails. convert() returns a new, fully loaded image, so it remains valid
    # after the source file is closed.
    with Image.open(input_path) as source:
        img = source.convert("RGB")
    return BRISQUE().score(img)
def main():
    """Group near-duplicate images in ./inputs, score them, and label redundant copies.

    Every entry gets a ``quality_score`` (BRISQUE) and a ``bad`` flag. Within
    each similarity group the first entry is kept (``bad=False``) and every
    other member is flagged as a near-duplicate (``bad=True``). Each cluster is
    printed, followed by the combined list of all entries.
    """
    similarity_threshold = 0.4
    # "inputs" (relative to the current working directory) holds the images.
    image_folder = os.path.join(os.getcwd(), "inputs")
    images, image_files = load_images(image_folder)
    similarity_matrix = calculate_similarity(images)
    image_groups = group_images_by_similarity(
        similarity_matrix, image_files, similarity_threshold)

    # Attach a quality score to every entry; singletons and multi-image groups
    # are handled identically.
    # NOTE: BRISQUE scores are conventionally lower-is-better, so any future
    # low-quality filter should flag `score > threshold`, not `<`.
    for group in image_groups:
        for img_info in group:
            img_path = os.path.join(image_folder, img_info["filename"])
            img_info["quality_score"] = get_brisque_score(img_path)

    # Combine all clusters. In each, the seed (first entry) is kept and every
    # later member is a near-duplicate of it.
    combined = []
    for cluster_id, group in enumerate(image_groups):
        for position, img_info in enumerate(group):
            img_info["bad"] = position > 0
        combined += group
        print("Clustered")
        print("-" * 30)
        print(f"Cluster {cluster_id + 1}: {group}")
        print("-" * 30)
    print("Combined")
    print(combined)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.