Skip to content

Instantly share code, notes, and snippets.

@amenabe22
Created November 1, 2023 20:23
Show Gist options
  • Save amenabe22/517aee4335ce5a455df6440f1846637c to your computer and use it in GitHub Desktop.
Save amenabe22/517aee4335ce5a455df6440f1846637c to your computer and use it in GitHub Desktop.
pariwise_similarity_filter.py
import functools
import os

import numpy as np
from brisque import BRISQUE
from PIL import Image
from skimage import io
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize
def _to_rgb(image):
    """Return *image* with exactly three channels on the last axis.

    calculate_similarity compares every pair of images with a fixed channel
    axis, so a grayscale (2-D) or RGBA image mixed in with RGB ones would make
    the SSIM call fail. Arrays of any other layout are returned unchanged.
    """
    if image.ndim == 2:
        # plain grayscale: replicate into three identical channels
        return np.stack([image] * 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 2:
        # grayscale + alpha: replicate the intensity channel, drop alpha
        return np.stack([image[..., 0]] * 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 4:
        # RGBA: drop alpha (no compositing against a background)
        return image[..., :3]
    return image


def load_images(image_folder, target_size=(256, 256)):
    """Load every image file in *image_folder*, normalised and resized.

    Args:
        image_folder: directory to scan (non-recursive).
        target_size: (rows, cols) passed to ``skimage.transform.resize``;
            any trailing channel axis is kept.

    Returns:
        ``(images, image_files)`` where ``image_files`` is the sorted list of
        file names that were loaded and ``images[k]`` is the resized array for
        ``image_files[k]``. Sorting makes results reproducible, because
        ``os.listdir`` order is arbitrary.
    """
    image_files = sorted(
        name for name in os.listdir(image_folder)
        # skip hidden files (e.g. .DS_Store) and sub-directories,
        # which io.imread cannot decode
        if not name.startswith(".")
        and os.path.isfile(os.path.join(image_folder, name))
    )
    # Resize each image right after reading it, so only one full-size image
    # is held in memory at a time.
    images = [
        resize(_to_rgb(io.imread(os.path.join(image_folder, name))), target_size)
        for name in image_files
    ]
    return images, image_files
def calculate_similarity(images, win_size=7, channel_axis=2):
    """Return the (n, n) matrix of pairwise SSIM scores for *images*.

    SSIM is symmetric in its two arguments, so only the upper triangle
    (including the diagonal) is computed and then mirrored. This roughly
    halves the number of expensive SSIM evaluations compared with scoring
    every ordered pair.

    Args:
        images: sequence of equally shaped arrays. Channels are assumed to be
            on ``channel_axis``, and values are presumed to lie in [0, 1]
            because ``data_range=1.0`` is fixed below.
        win_size: SSIM sliding-window size.
        channel_axis: axis holding the colour channels.

    Returns:
        ``numpy.ndarray`` of shape (n, n); entry [i, j] is SSIM(images[i], images[j]).
    """
    n_images = len(images)
    similarity_matrix = np.zeros((n_images, n_images))
    for i in range(n_images):
        for j in range(i, n_images):
            # NOTE(review): `multichannel` is the legacy spelling of
            # `channel_axis`; it is kept for older scikit-image releases.
            # Confirm whether the pinned version still needs it.
            score = ssim(
                images[i], images[j], win_size=win_size, multichannel=True,
                data_range=1.0, channel_axis=channel_axis)
            similarity_matrix[i, j] = similarity_matrix[j, i] = score
    return similarity_matrix
def group_images_by_similarity(similarity_matrix, image_files, similarity_threshold):
    """Greedily cluster images whose pairwise similarity exceeds a threshold.

    Images are visited in order. Each image not yet assigned starts a new
    group (it becomes the group's seed). The seed pulls in every later,
    still-unassigned image whose similarity to it is strictly greater than
    ``similarity_threshold``. Every image ends up in exactly one group.

    Args:
        similarity_matrix: square array indexed like ``image_files``.
        image_files: file names in the same order as the matrix rows/columns.
        similarity_threshold: minimum (exclusive) similarity to join a seed.

    Returns:
        A list of groups. Each group is a list of dicts with keys
        ``"filename"`` and ``"similarity_score"``. The seed comes first and
        carries its own self-similarity ``similarity_matrix[i, i]``; later
        entries carry their similarity to the seed.
    """
    n_images = len(image_files)
    image_groups = []
    grouped = set()  # indices already assigned to some group
    for i in range(n_images):
        if i in grouped:
            continue
        grouped.add(i)
        group = [{"filename": image_files[i],
                  "similarity_score": similarity_matrix[i, i]}]
        for j in range(i + 1, n_images):
            # An image claimed by an earlier group must not be duplicated here.
            if j not in grouped and similarity_matrix[i, j] > similarity_threshold:
                group.append({"filename": image_files[j],
                              "similarity_score": similarity_matrix[i, j]})
                grouped.add(j)
        image_groups.append(group)
    return image_groups
@functools.lru_cache(maxsize=None)
def _brisque_model():
    """Build the BRISQUE scorer once and reuse it for every image."""
    return BRISQUE()


def get_brisque_score(input_path):
    """Return the BRISQUE blind (no-reference) quality score of an image file.

    The image is converted to RGB before scoring. With BRISQUE, a lower score
    indicates better perceived quality.

    Args:
        input_path: path to an image file readable by Pillow.

    Returns:
        The value returned by ``BRISQUE.score``.
    """
    # The context manager closes the file handle Pillow opens lazily;
    # convert() loads the pixels into a new image before the file is closed.
    with Image.open(input_path) as img:
        rgb = img.convert("RGB")
    return _brisque_model().score(rgb)
def _attach_quality_scores(image_groups, image_folder):
    """Store a BRISQUE ``quality_score`` on every image dict in *image_groups* (in place)."""
    for group in image_groups:
        for img_info in group:
            img_path = os.path.join(image_folder, img_info["filename"])
            img_info["quality_score"] = get_brisque_score(img_path)


def _label_group(group):
    """Set ``bad`` on each image dict: only the lowest-BRISQUE image is kept.

    A lower BRISQUE score means better perceived quality, so the
    lowest-scoring member of a cluster of near-duplicates is kept
    (``bad=False``) and the rest are flagged ``bad=True``. Ties go to the
    earliest entry, i.e. the seed image when all scores are equal. A
    single-image group is therefore never flagged.
    """
    keeper = min(group, key=lambda info: info["quality_score"])
    for img_info in group:
        img_info["bad"] = img_info is not keeper


def main(image_folder=None, similarity_threshold=0.4):
    """Cluster near-duplicate images in a folder and flag redundant copies.

    Pipeline:
        1. Load and resize the images.
        2. Compute the pairwise SSIM matrix.
        3. Group images whose SSIM to a group seed exceeds the threshold.
        4. Score every image with BRISQUE.
        5. In each group, keep the best-quality image and mark the others ``bad``.

    Each cluster is printed, followed by the flattened list of all image
    dicts (keys: filename, similarity_score, quality_score, bad).

    Args:
        image_folder: folder holding the input images; defaults to
            ``./inputs`` under the current working directory.
        similarity_threshold: SSIM value (exclusive) above which an image
            joins a group.
    """
    if image_folder is None:
        image_folder = os.path.join(os.getcwd(), "inputs")

    images, image_files = load_images(image_folder)
    similarity_matrix = calculate_similarity(images)
    image_groups = group_images_by_similarity(
        similarity_matrix, image_files, similarity_threshold)

    _attach_quality_scores(image_groups, image_folder)

    combined = []
    for cluster_id, group in enumerate(image_groups):
        _label_group(group)
        combined += group
        print("Clustered")
        print("-"*30)
        print(f"Cluster {cluster_id + 1}: {group}")
        print("-"*30)
    print("Combined")
    print(combined)
if __name__ == "__main__":
    # Run the pipeline only when this file is executed directly, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment