Skip to content

Instantly share code, notes, and snippets.

@fivejjs
Forked from tanukon/embedding_comparison.ipynb
Created October 17, 2024 10:16
Show Gist options
  • Save fivejjs/30738f060d80df8deefbf7390e2ff3f3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/envs/transformers-env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import faiss\n",
"import os\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"from transformers import AutoImageProcessor, EfficientNetModel, ViTModel, AutoModel, CLIPProcessor, CLIPModel, Blip2Processor, Blip2Model\n",
"\n",
"from transformers import Pipeline\n",
"from transformers.image_utils import load_image"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Root directory of the Flickr30k images.\n",
"# Set the FLICKR30K_DIR environment variable or replace the placeholder below.\n",
"dataset_dir = os.environ.get('FLICKR30K_DIR', '<Your dataset directory>')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(500, 375)\n"
]
}
],
"source": [
"# one sample image used to sanity-check every feature extractor below\n",
"test_image_path = os.path.join(dataset_dir, '36979.jpg')\n",
"test_image = Image.open(test_image_path)\n",
"\n",
"# PIL reports the size as (width, height)\n",
"print(test_image.size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EfficientNet feature extraction"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# EfficientNet-B7: matching image processor (resize/normalize) and the bare backbone (no classification head)\n",
"efficientnet_checkpoint = 'google/efficientnet-b7'\n",
"image_processor = AutoImageProcessor.from_pretrained(efficientnet_checkpoint)\n",
"model = EfficientNetModel.from_pretrained(efficientnet_checkpoint)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input shape: torch.Size([1, 3, 600, 600])\n",
"embedding shape: torch.Size([1, 640, 19, 19])\n",
"after reducing: torch.Size([1, 640])\n"
]
}
],
"source": [
"# turn the sample image into a batch of one pixel tensor\n",
"inputs = image_processor(test_image, return_tensors='pt')\n",
"print('input shape: ', inputs['pixel_values'].shape)\n",
"\n",
"# inference only: no autograd; also return every intermediate hidden state\n",
"with torch.no_grad():\n",
"    outputs = model(**inputs, output_hidden_states=True)\n",
"\n",
"# the last hidden state is a spatial feature map (batch, channels, height, width)\n",
"feature_map = outputs.hidden_states[-1]\n",
"print('embedding shape: ', feature_map.shape)\n",
"\n",
"# global average pooling over height and width -> one vector per image\n",
"embedding = feature_map.mean(dim=(2, 3))\n",
"print('after reducing: ', embedding.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ViT feature extraction"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.\n"
]
}
],
"source": [
"# ViT-Large (16x16 patches, 224px input, ImageNet-21k weights): image processor and bare encoder\n",
"vit_checkpoint = 'google/vit-large-patch16-224-in21k'\n",
"image_processor = AutoImageProcessor.from_pretrained(vit_checkpoint)\n",
"model = ViTModel.from_pretrained(vit_checkpoint)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input shape: torch.Size([1, 3, 224, 224])\n",
"embedding shape: torch.Size([1, 1024])\n"
]
}
],
"source": [
"# turn the sample image into a batch of one pixel tensor\n",
"inputs = image_processor(test_image, return_tensors='pt')\n",
"print('input shape: ', inputs['pixel_values'].shape)\n",
"\n",
"with torch.no_grad():\n",
"    outputs = model(**inputs)\n",
"\n",
"# position 0 of the token sequence is ViT's [CLS] token; use it as the image embedding\n",
"cls_token = outputs.last_hidden_state[:, 0, :]\n",
"embedding = cls_token.squeeze(1)  # no-op here: dim 1 is the hidden size, not 1\n",
"print('embedding shape: ', embedding.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DINO-v2 feature extraction"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# DINOv2 base: image processor and backbone, resolved through the Auto classes\n",
"dinov2_checkpoint = 'facebook/dinov2-base'\n",
"image_processor = AutoImageProcessor.from_pretrained(dinov2_checkpoint)\n",
"model = AutoModel.from_pretrained(dinov2_checkpoint)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input shape: torch.Size([1, 3, 224, 224])\n",
"embedding shape: torch.Size([1, 768])\n"
]
}
],
"source": [
"# turn the sample image into a batch of one pixel tensor\n",
"inputs = image_processor(images=test_image, return_tensors='pt')\n",
"print('input shape: ', inputs['pixel_values'].shape)\n",
"\n",
"with torch.no_grad():\n",
"    outputs = model(**inputs)\n",
"\n",
"# first token of the last hidden state (the [CLS] token) serves as the image embedding\n",
"cls_token = outputs.last_hidden_state[:, 0, :]\n",
"embedding = cls_token.squeeze(1)  # no-op here: dim 1 is the hidden size, not 1\n",
"print('embedding shape: ', embedding.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CLIP feature extraction"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/envs/transformers-env/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
}
],
"source": [
"# CLIP ViT-B/32: CLIPProcessor bundles the image processor together with the text tokenizer\n",
"clip_checkpoint = 'openai/clip-vit-base-patch32'\n",
"image_processor = CLIPProcessor.from_pretrained(clip_checkpoint)\n",
"model = CLIPModel.from_pretrained(clip_checkpoint)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input shape: torch.Size([1, 3, 224, 224])\n",
"embedding shape: torch.Size([1, 512])\n"
]
}
],
"source": [
"# turn the sample image into a batch of one pixel tensor\n",
"inputs = image_processor(images=test_image, return_tensors='pt', padding=True)\n",
"print('input shape: ', inputs['pixel_values'].shape)\n",
"\n",
"# get_image_features already returns one projected vector per image: no pooling step needed\n",
"with torch.no_grad():\n",
"    outputs = model.get_image_features(**inputs)\n",
"\n",
"print('embedding shape: ', outputs.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BLIP-2 feature extraction"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/envs/transformers-env/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.20s/it]\n"
]
}
],
"source": [
"# BLIP-2 with the OPT-2.7b language model; weights are loaded in float16 to halve memory\n",
"blip2_checkpoint = 'Salesforce/blip2-opt-2.7b'\n",
"image_processor = Blip2Processor.from_pretrained(blip2_checkpoint)\n",
"model = Blip2Model.from_pretrained(blip2_checkpoint, torch_dtype=torch.float16)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input shape: torch.Size([1, 3, 224, 224])\n",
"after reducing: torch.Size([1, 768])\n"
]
}
],
"source": [
"# turn the sample image into a batch of one pixel tensor\n",
"inputs = image_processor(images=test_image, return_tensors='pt', padding=True)\n",
"print('input shape: ', inputs['pixel_values'].shape)\n",
"\n",
"# Q-Former hidden states, assumed to be (batch, query tokens, hidden)\n",
"with torch.no_grad():\n",
"    outputs = model.get_qformer_features(**inputs).last_hidden_state\n",
"\n",
"# average over axis 1 (the query tokens) -> one vector per image\n",
"embedding = outputs.mean(dim=1).squeeze(1)\n",
"print('after reducing: ', embedding.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Image similarity search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### custom pipeline for EfficientNet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class EfficientNetPipeline(Pipeline):\n",
"    \"\"\"Image-embedding pipeline for EfficientNet.\n",
"\n",
"    The embedding is the model's last hidden state (a channels x H x W feature\n",
"    map) averaged over its two spatial axes.\n",
"    \"\"\"\n",
"\n",
"    def _sanitize_parameters(self, **kwargs):\n",
"        # no call-time options: empty kwargs for preprocess, _forward and postprocess\n",
"        return {}, {}, {}\n",
"\n",
"    def preprocess(self, image):\n",
"        \"\"\"Load `image` (path, URL or PIL image) and convert it to pixel tensors.\"\"\"\n",
"        image = load_image(image)\n",
"        model_inputs = self.image_processor(images=image, return_tensors=\"pt\")\n",
"\n",
"        return model_inputs\n",
"\n",
"    def _forward(self, model_inputs):\n",
"        \"\"\"Run the backbone, keeping all intermediate hidden states.\"\"\"\n",
"        with torch.no_grad():\n",
"            outputs = self.model(**model_inputs, output_hidden_states=True)\n",
"\n",
"        return outputs\n",
"\n",
"    def postprocess(self, model_outputs):\n",
"        \"\"\"Global-average-pool the last hidden state over dims 2 and 3.\"\"\"\n",
"        embedding = model_outputs.hidden_states[-1]\n",
"        embedding = torch.mean(embedding, dim=[2, 3])\n",
"\n",
"        return embedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### custom pipeline for ViT"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class ViTPipeline(Pipeline):\n",
"    \"\"\"Image-embedding pipeline for ViT: the embedding is the [CLS] token (position 0)\n",
"    of the last hidden state.\n",
"    \"\"\"\n",
"\n",
"    def _sanitize_parameters(self, **kwargs):\n",
"        # no call-time options: empty kwargs for preprocess, _forward and postprocess\n",
"        return {}, {}, {}\n",
"\n",
"    def preprocess(self, image):\n",
"        \"\"\"Load `image` (path, URL or PIL image) and convert it to pixel tensors.\"\"\"\n",
"        image = load_image(image)\n",
"        model_inputs = self.image_processor(images=image, return_tensors=\"pt\")\n",
"\n",
"        return model_inputs\n",
"\n",
"    def _forward(self, model_inputs):\n",
"        \"\"\"Run the encoder without autograd.\"\"\"\n",
"        with torch.no_grad():\n",
"            outputs = self.model(**model_inputs)\n",
"\n",
"        return outputs\n",
"\n",
"    def postprocess(self, model_outputs):\n",
"        \"\"\"Select token 0 of the last hidden state as the image embedding.\"\"\"\n",
"        embedding = model_outputs.last_hidden_state\n",
"        # squeeze(1) is a no-op unless the hidden size is 1\n",
"        embedding = embedding[:, 0, :].squeeze(1)\n",
"\n",
"        return embedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### custom pipeline for DINO-v2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class DINOv2Pipeline(Pipeline):\n",
"    \"\"\"Image-embedding pipeline for DINOv2: the embedding is the [CLS] token\n",
"    (position 0) of the last hidden state.\n",
"    \"\"\"\n",
"\n",
"    def _sanitize_parameters(self, **kwargs):\n",
"        # no call-time options: empty kwargs for preprocess, _forward and postprocess\n",
"        return {}, {}, {}\n",
"\n",
"    def preprocess(self, image):\n",
"        \"\"\"Load `image` (path, URL or PIL image) and convert it to pixel tensors.\"\"\"\n",
"        image = load_image(image)\n",
"        model_inputs = self.image_processor(images=image, return_tensors=\"pt\")\n",
"\n",
"        return model_inputs\n",
"\n",
"    def _forward(self, model_inputs):\n",
"        \"\"\"Run the backbone without autograd.\"\"\"\n",
"        with torch.no_grad():\n",
"            outputs = self.model(**model_inputs)\n",
"\n",
"        return outputs\n",
"\n",
"    def postprocess(self, model_outputs):\n",
"        \"\"\"Select token 0 of the last hidden state as the image embedding.\"\"\"\n",
"        embedding = model_outputs.last_hidden_state\n",
"        # squeeze(1) is a no-op unless the hidden size is 1\n",
"        embedding = embedding[:, 0, :].squeeze(1)\n",
"\n",
"        return embedding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### custom pipeline for CLIP"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class CLIPPipeline(Pipeline):\n",
"    \"\"\"Image-embedding pipeline for CLIP, using `CLIPModel.get_image_features`.\"\"\"\n",
"\n",
"    def _sanitize_parameters(self, **kwargs):\n",
"        # no call-time options: empty kwargs for preprocess, _forward and postprocess\n",
"        return {}, {}, {}\n",
"\n",
"    def preprocess(self, image):\n",
"        \"\"\"Load `image` (path, URL or PIL image) and convert it to pixel tensors.\"\"\"\n",
"        image = load_image(image)\n",
"        model_inputs = self.image_processor(images=image, return_tensors=\"pt\")\n",
"\n",
"        return model_inputs\n",
"\n",
"    def _forward(self, model_inputs):\n",
"        \"\"\"Compute the image features without autograd.\"\"\"\n",
"        with torch.no_grad():\n",
"            outputs = self.model.get_image_features(**model_inputs)\n",
"\n",
"        return outputs\n",
"\n",
"    def postprocess(self, model_outputs):\n",
"        \"\"\"The image features are already the embedding; return them unchanged.\"\"\"\n",
"        return model_outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### custom pipeline for BLIP2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class BLIP2Pipeline(Pipeline):\n",
"    \"\"\"Image-embedding pipeline for BLIP-2: the Q-Former's last hidden state\n",
"    averaged over axis 1 (assumed to be the query-token axis).\n",
"    \"\"\"\n",
"\n",
"    def _sanitize_parameters(self, **kwargs):\n",
"        # no call-time options: empty kwargs for preprocess, _forward and postprocess\n",
"        return {}, {}, {}\n",
"\n",
"    def preprocess(self, image):\n",
"        \"\"\"Load `image` (path, URL or PIL image) and convert it to pixel tensors.\"\"\"\n",
"        image = load_image(image)\n",
"        model_inputs = self.image_processor(images=image, return_tensors=\"pt\")\n",
"\n",
"        return model_inputs\n",
"\n",
"    def _forward(self, model_inputs):\n",
"        \"\"\"Compute the Q-Former features without autograd.\"\"\"\n",
"        with torch.no_grad():\n",
"            outputs = self.model.get_qformer_features(**model_inputs)\n",
"\n",
"        return outputs\n",
"\n",
"    def postprocess(self, model_outputs):\n",
"        \"\"\"Mean-pool the Q-Former hidden states over dim 1.\"\"\"\n",
"        embedding = model_outputs.last_hidden_state\n",
"        embedding = torch.mean(embedding, dim=1).squeeze(1)\n",
"\n",
"        return embedding"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def define_model(model_name: str):\n",
"    \"\"\"Load the pre-trained image processor and model for `model_name`.\n",
"\n",
"    Args:\n",
"        model_name: one of 'EfficientNet', 'ViT', 'DINO-v2', 'CLIP', 'BLIP2'.\n",
"\n",
"    Returns:\n",
"        (image_processor, model) tuple. For CLIP and BLIP2 the first element is the\n",
"        full processor (CLIPProcessor / Blip2Processor).\n",
"\n",
"    Raises:\n",
"        ValueError: if `model_name` is not one of the names above.\n",
"    \"\"\"\n",
"    if model_name == 'EfficientNet':\n",
"        image_processor = AutoImageProcessor.from_pretrained('google/efficientnet-b7')\n",
"        model = EfficientNetModel.from_pretrained('google/efficientnet-b7')\n",
"    elif model_name == 'ViT':\n",
"        image_processor = AutoImageProcessor.from_pretrained('google/vit-large-patch16-224-in21k')\n",
"        model = ViTModel.from_pretrained('google/vit-large-patch16-224-in21k')\n",
"    elif model_name == 'DINO-v2':\n",
"        image_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')\n",
"        model = AutoModel.from_pretrained('facebook/dinov2-base')\n",
"    elif model_name == 'CLIP':\n",
"        image_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')\n",
"        model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')\n",
"    elif model_name == 'BLIP2':\n",
"        image_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')\n",
"        # half precision keeps the memory footprint of this large checkpoint down\n",
"        model = Blip2Model.from_pretrained('Salesforce/blip2-opt-2.7b', torch_dtype=torch.float16)\n",
"    else:\n",
"        # without this branch an unknown name would surface as an UnboundLocalError below\n",
"        raise ValueError(f'Unknown model_name {model_name!r}; expected one of: EfficientNet, ViT, DINO-v2, CLIP, BLIP2')\n",
"\n",
"    return image_processor, model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def load_pipeline(model_name: str, device=0):\n",
"    \"\"\"Build the custom feature-extraction pipeline for `model_name`.\n",
"\n",
"    Args:\n",
"        model_name: one of 'EfficientNet', 'ViT', 'DINO-v2', 'CLIP', 'BLIP2'.\n",
"        device: passed to the transformers Pipeline (0 = first CUDA GPU, -1 = CPU).\n",
"\n",
"    Returns:\n",
"        An instance of the matching *Pipeline class wrapping the pre-trained model.\n",
"\n",
"    Raises:\n",
"        ValueError: if `model_name` is unknown; checked before any weights are loaded.\n",
"    \"\"\"\n",
"    pipeline_classes = {\n",
"        'EfficientNet': EfficientNetPipeline,\n",
"        'ViT': ViTPipeline,\n",
"        'DINO-v2': DINOv2Pipeline,\n",
"        'CLIP': CLIPPipeline,\n",
"        'BLIP2': BLIP2Pipeline,\n",
"    }\n",
"    if model_name not in pipeline_classes:\n",
"        raise ValueError(f'Unknown model_name {model_name!r}; expected one of {list(pipeline_classes)}')\n",
"\n",
"    image_processor, model = define_model(model_name=model_name)\n",
"    pipeline_cls = pipeline_classes[model_name]\n",
"\n",
"    return pipeline_cls(model=model, image_processor=image_processor, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def register_embeddings(embeddings):\n",
"    \"\"\"Build an exact inner-product Faiss index over L2-normalised embeddings.\n",
"\n",
"    Inner product on unit vectors equals cosine similarity.\n",
"    NOTE: `faiss.normalize_L2` normalises `embeddings` IN PLACE, so the caller's\n",
"    array is modified; it must be a float32 array of shape (n_items, dim).\n",
"    \"\"\"\n",
"    vector_dim = embeddings.shape[1]\n",
"\n",
"    index = faiss.IndexFlatIP(vector_dim)\n",
"    faiss.normalize_L2(embeddings)\n",
"    index.add(embeddings)\n",
"\n",
"    return index\n",
"\n",
"\n",
"def image_similarity_search(embeddings, index, image_name_list, model_name, result_dir, top_k=6):\n",
"    \"\"\"Find the `top_k` nearest neighbours of every embedding and save them as CSV.\n",
"\n",
"    Args:\n",
"        embeddings: array of shape (n_queries, dim); row i is the query for image i.\n",
"        index: Faiss index built by `register_embeddings`.\n",
"        image_name_list: names/paths aligned with the rows added to `index`.\n",
"        model_name: used as the CSV file stem.\n",
"        result_dir: directory where `<model_name>.csv` is written.\n",
"        top_k: neighbours per query; yields columns top0_similar ... top{top_k-1}_similar.\n",
"\n",
"    Neighbours are ordered from most to least similar, so when the query itself is\n",
"    in the index top0_similar is normally the query image. Slots Faiss cannot fill\n",
"    (top_k larger than the index) are written as empty cells.\n",
"\n",
"    Raises:\n",
"        ValueError: if `top_k` is smaller than 1.\n",
"    \"\"\"\n",
"    if top_k < 1:\n",
"        raise ValueError(f'top_k must be >= 1, got {top_k}')\n",
"\n",
"    # normalise a float32 copy so inner product == cosine similarity without mutating the caller's array\n",
"    queries = np.array(embeddings, dtype=np.float32, order='C')\n",
"    if len(queries) == 0:\n",
"        neighbour_ids = np.empty((0, top_k), dtype=np.int64)\n",
"    else:\n",
"        faiss.normalize_L2(queries)\n",
"        # one batched search instead of one call per query\n",
"        _, neighbour_ids = index.search(queries, k=top_k)\n",
"\n",
"    # one column per rank, sized by top_k rather than a fixed six columns\n",
"    result_dict = {f'top{k}_similar': [] for k in range(top_k)}\n",
"    for row in neighbour_ids:\n",
"        for k, idx in enumerate(row):\n",
"            # Faiss pads missing neighbours with -1, which would otherwise index the last name\n",
"            result_dict[f'top{k}_similar'].append(image_name_list[idx] if idx >= 0 else None)\n",
"\n",
"    df = pd.DataFrame.from_dict(result_dict)\n",
"    df.to_csv(os.path.join(result_dir, f'{model_name}.csv'), index=False)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def collect_data(dataset_dir, extensions=('.jpg', '.jpeg')):\n",
"    \"\"\"Return the paths of all JPEG images under `dataset_dir`, searched recursively.\n",
"\n",
"    Args:\n",
"        dataset_dir: root directory to walk.\n",
"        extensions: lowercase file extensions to accept (matched case-insensitively).\n",
"\n",
"    Returns:\n",
"        Sorted list of file paths. Matching is on the extension, so names that merely\n",
"        contain 'jpg' (e.g. 'notes_jpg.txt') are not picked up. Sorting makes the order\n",
"        independent of the filesystem's listing order; it must line up with the rows of\n",
"        any cached `<model_name>.npy` embeddings, so a cache built from a differently\n",
"        ordered list has to be deleted and recomputed.\n",
"    \"\"\"\n",
"    files_list = []\n",
"    for root, _, files in os.walk(dataset_dir):\n",
"        for f in files:\n",
"            if f.lower().endswith(extensions):\n",
"                files_list.append(os.path.join(root, f))\n",
"\n",
"    return sorted(files_list)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# backbone to evaluate: 'EfficientNet', 'ViT', 'DINO-v2', 'CLIP' or 'BLIP2'\n",
"model_name = 'ViT'\n",
"# number of images passed to the pipeline per call\n",
"batch_size = 16\n",
"\n",
"# Flickr30k image directory (override with the FLICKR30K_DIR environment variable)\n",
"dataset_dir = os.environ.get('FLICKR30K_DIR', '<Your dataset directory>')\n",
"# cached embeddings (<model_name>.npy) and neighbour tables (<model_name>.csv) go here\n",
"result_dir = './results_test'\n",
"os.makedirs(result_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.\n",
" 2%|▏ | 10/625 [00:11<07:02, 1.46it/s]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n",
"100%|██████████| 625/625 [06:37<00:00, 1.57it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"embedding shape: (10000, 1024)\n"
]
}
],
"source": [
"# image paths; row i of the embedding matrix belongs to dataset[i]\n",
"dataset = collect_data(dataset_dir=dataset_dir)\n",
"if not dataset:\n",
"    raise FileNotFoundError(f'No images found under {dataset_dir!r}; set dataset_dir first')\n",
"\n",
"cache_path = os.path.join(result_dir, f'{model_name}.npy')\n",
"\n",
"if os.path.exists(cache_path):\n",
"    embeddings = np.load(cache_path)\n",
"    print('embedding shape: ', embeddings.shape)\n",
"else:\n",
"    # the (large) model is only loaded when embeddings actually have to be computed\n",
"    pipeline = load_pipeline(model_name=model_name)\n",
"    embedding_chunks = []\n",
"\n",
"    for start in tqdm(range(0, len(dataset), batch_size)):\n",
"        batch_paths = dataset[start:start + batch_size]\n",
"\n",
"        out = pipeline(batch_paths, batch_size=batch_size)\n",
"        embedding_chunks += [embed.detach().cpu() for embed in out]\n",
"\n",
"        # release GPU memory held by this batch before the next one\n",
"        del out\n",
"        torch.cuda.empty_cache()\n",
"\n",
"    embeddings = torch.cat(embedding_chunks, dim=0).detach().cpu().numpy()\n",
"    print('embedding shape: ', embeddings.shape)\n",
"    np.save(cache_path, embeddings)\n",
"\n",
"# a cache written for a different file list would silently pair embeddings with the wrong images\n",
"if len(embeddings) != len(dataset):\n",
"    raise ValueError(\n",
"        f'{cache_path} has {len(embeddings)} rows but {len(dataset)} images were found; '\n",
"        'delete the cache and re-run'\n",
"    )\n",
"\n",
"# similarity search by Faiss (cosine similarity via inner product on L2-normalised vectors)\n",
"embeddings = embeddings.astype(np.float32)\n",
"index = register_embeddings(embeddings)\n",
"image_similarity_search(embeddings=embeddings, index=index, image_name_list=dataset, model_name=model_name, result_dir=result_dir)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "transformers-env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment