Last active
December 16, 2024 16:19
-
-
Save jamessdixon/a85a28b3d6c224282f0cc5f674e4c25d to your computer and use it in GitHub Desktop.
image validation against truth set
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import warnings\n", | |
"from urllib3.exceptions import NotOpenSSLWarning\n", | |
"warnings.filterwarnings(\"ignore\", category=NotOpenSSLWarning)\n", | |
"\n", | |
"import os\n", | |
"import random\n", | |
"import torch\n", | |
"from PIL import Image\n", | |
"from transformers import CLIPProcessor, CLIPModel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
".venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", | |
" warnings.warn(\n" | |
] | |
} | |
], | |
"source": [ | |
"# Step 1: Load CLIP model and processor\n", | |
"# Both objects are created from the same checkpoint name, kept in one place.\n", | |
"clip_checkpoint = \"openai/clip-vit-base-patch32\"\n", | |
"model = CLIPModel.from_pretrained(clip_checkpoint)\n", | |
"processor = CLIPProcessor.from_pretrained(clip_checkpoint)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Step 2: Generate embeddings for truth set\n", | |
"\n", | |
"truth_set_directory = \"truth_set\"\n", | |
"\n", | |
"# Every regular file directly inside the directory is treated as an image.\n", | |
"# Sorted so the processing order does not depend on os.listdir(), whose\n", | |
"# ordering is arbitrary.\n", | |
"truth_images = sorted(\n", | |
"    os.path.join(truth_set_directory, f)\n", | |
"    for f in os.listdir(truth_set_directory)\n", | |
"    if os.path.isfile(os.path.join(truth_set_directory, f))\n", | |
")\n", | |
"\n", | |
"# Fail early with a clear message; an empty list would otherwise only surface\n", | |
"# later, when the embeddings are stacked.\n", | |
"if not truth_images:\n", | |
"    raise ValueError(f\"No files found in the directory: {truth_set_directory}\")\n", | |
"\n", | |
"truth_embeddings = []\n", | |
"\n", | |
"for image_path in truth_images:\n", | |
"    # The context manager closes the underlying file once convert() has\n", | |
"    # produced its own RGB copy of the pixel data.\n", | |
"    with Image.open(image_path) as source_image:\n", | |
"        image = source_image.convert(\"RGB\")\n", | |
"    inputs = processor(images=image, return_tensors=\"pt\", padding=True)\n", | |
"    with torch.no_grad():\n", | |
"        # squeeze() removes size-1 dimensions (presumably the batch axis of\n", | |
"        # this single-image batch - TODO confirm).\n", | |
"        embedding = model.get_image_features(**inputs).squeeze()\n", | |
"    truth_embeddings.append(embedding)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Stack the per-image embeddings and take their mean along the first axis\n", | |
"stacked_truth_embeddings = torch.stack(truth_embeddings)\n", | |
"truth_set_avg_embedding = stacked_truth_embeddings.mean(dim=0)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Step 3: Generate embedding for query image\n", | |
"\n", | |
"def get_random_query_image(truth_set_directory):\n", | |
"    \"\"\"Return the path of one randomly chosen file in a directory.\n", | |
"\n", | |
"    Only regular files directly inside ``truth_set_directory`` are\n", | |
"    candidates; anything that is not a regular file is ignored.\n", | |
"\n", | |
"    Args:\n", | |
"        truth_set_directory: Path of the directory to pick a file from.\n", | |
"\n", | |
"    Returns:\n", | |
"        The directory path joined with the chosen file name.\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: If the directory contains no regular files.\n", | |
"    \"\"\"\n", | |
"    candidates = []\n", | |
"    for entry in os.listdir(truth_set_directory):\n", | |
"        full_path = os.path.join(truth_set_directory, entry)\n", | |
"        if os.path.isfile(full_path):\n", | |
"            candidates.append(full_path)\n", | |
"\n", | |
"    if len(candidates) == 0:\n", | |
"        raise ValueError(f\"No files found in the directory: {truth_set_directory}\")\n", | |
"\n", | |
"    return random.choice(candidates)\n", | |
"\n", | |
"# NOTE(review): the query is sampled from the truth-set directory itself, so\n", | |
"# the same image is also among the truth-set files - confirm this is intended.\n", | |
"truth_set_directory = \"truth_set\"\n", | |
"query_image_path = get_random_query_image(truth_set_directory)\n", | |
"# Open inside a context manager so the file handle is released as soon as the\n", | |
"# converted RGB copy exists.\n", | |
"with Image.open(query_image_path) as source_image:\n", | |
"    query_image = source_image.convert(\"RGB\")\n", | |
"query_inputs = processor(images=query_image, return_tensors=\"pt\", padding=True)\n", | |
"with torch.no_grad():\n", | |
"    # squeeze() removes size-1 dimensions from the model output.\n", | |
"    query_embedding = model.get_image_features(**query_inputs).squeeze()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Query image is similar to the truth set: True\n" | |
] | |
} | |
], | |
"source": [ | |
"# Score the query against the averaged truth-set embedding with cosine\n", | |
"# similarity, then compare the score to a fixed cut-off.\n", | |
"threshold = 0.8\n", | |
"cosine_sim = torch.nn.functional.cosine_similarity(\n", | |
"    query_embedding,\n", | |
"    truth_set_avg_embedding,\n", | |
"    dim=0,\n", | |
")\n", | |
"is_similar = torch.gt(cosine_sim, threshold)\n", | |
"\n", | |
"print(f\"Query image is similar to the truth set: {bool(is_similar)}\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment