Created
June 19, 2024 20:29
-
-
Save tspannhw/8e2ec1293c1cff1edaefbf7fde54f47a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "fc64742b-d1ba-4926-9018-53800cae9e05", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Requirement already satisfied: boto3 in ./milvusvenv/lib/python3.12/site-packages (1.34.129)\n", | |
| "Requirement already satisfied: botocore<1.35.0,>=1.34.129 in ./milvusvenv/lib/python3.12/site-packages (from boto3) (1.34.129)\n", | |
| "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in ./milvusvenv/lib/python3.12/site-packages (from boto3) (1.0.1)\n", | |
| "Requirement already satisfied: s3transfer<0.11.0,>=0.10.0 in ./milvusvenv/lib/python3.12/site-packages (from boto3) (0.10.1)\n", | |
| "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in ./milvusvenv/lib/python3.12/site-packages (from botocore<1.35.0,>=1.34.129->boto3) (2.9.0.post0)\n", | |
| "Requirement already satisfied: urllib3!=2.2.0,<3,>=1.25.4 in ./milvusvenv/lib/python3.12/site-packages (from botocore<1.35.0,>=1.34.129->boto3) (2.2.1)\n", | |
| "Requirement already satisfied: six>=1.5 in ./milvusvenv/lib/python3.12/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.35.0,>=1.34.129->boto3) (1.16.0)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "!pip install boto3" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "31fd4a20-50ce-4150-9a53-7b862dd2f9de", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from __future__ import print_function\n", | |
| "import requests\n", | |
| "import sys\n", | |
| "import io\n", | |
| "import json\n", | |
| "import shutil\n", | |
| "import sys\n", | |
| "import datetime\n", | |
| "import subprocess\n", | |
| "import sys\n", | |
| "import os\n", | |
| "import math\n", | |
| "import base64\n", | |
| "from time import gmtime, strftime\n", | |
| "import random, string\n", | |
| "import time\n", | |
| "import psutil\n", | |
| "import base64\n", | |
| "import uuid\n", | |
| "import socket\n", | |
| "import os\n", | |
| "from pymilvus import connections\n", | |
| "from pymilvus import utility\n", | |
| "from pymilvus import FieldSchema, CollectionSchema, DataType, Collection\n", | |
| "import torch\n", | |
| "from torchvision import transforms\n", | |
| "from PIL import Image\n", | |
| "import timm\n", | |
| "from sklearn.preprocessing import normalize\n", | |
| "from timm.data import resolve_data_config\n", | |
| "from timm.data.transforms_factory import create_transform\n", | |
| "from pymilvus import MilvusClient\n", | |
| "import os\n", | |
| "from IPython.display import display" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "4a56beb4-67b0-480f-ae1f-20bc6f7daaff", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from __future__ import print_function\n", | |
| "import requests\n", | |
| "import sys\n", | |
| "import io\n", | |
| "import json\n", | |
| "import shutil\n", | |
| "import sys\n", | |
| "import datetime\n", | |
| "import subprocess\n", | |
| "import sys\n", | |
| "import os\n", | |
| "import math\n", | |
| "import base64\n", | |
| "from time import gmtime, strftime\n", | |
| "import random, string\n", | |
| "import time\n", | |
| "import psutil\n", | |
| "import base64\n", | |
| "import uuid\n", | |
| "import socket\n", | |
| "import os\n", | |
| "from pymilvus import connections\n", | |
| "from pymilvus import utility\n", | |
| "from pymilvus import FieldSchema, CollectionSchema, DataType, Collection\n", | |
| "import torch\n", | |
| "from torchvision import transforms\n", | |
| "from PIL import Image\n", | |
| "import timm\n", | |
| "from sklearn.preprocessing import normalize\n", | |
| "from timm.data import resolve_data_config\n", | |
| "from timm.data.transforms_factory import create_transform\n", | |
| "from pymilvus import MilvusClient\n", | |
| "import os\n", | |
| "from IPython.display import display" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "1cb009b2-ef4a-4716-a962-6021ff9e6199", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# -----------------------------------------------------------------------------\n", | |
| "\n", | |
| "class FeatureExtractor:\n", | |
| " def __init__(self, modelname):\n", | |
| " # Load the pre-trained model\n", | |
| " self.model = timm.create_model(\n", | |
| " modelname, pretrained=True, num_classes=0, global_pool=\"avg\"\n", | |
| " )\n", | |
| " self.model.eval()\n", | |
| "\n", | |
| " # Get the input size required by the model\n", | |
| " self.input_size = self.model.default_cfg[\"input_size\"]\n", | |
| "\n", | |
| " config = resolve_data_config({}, model=modelname)\n", | |
| " # Get the preprocessing function provided by TIMM for the model\n", | |
| " self.preprocess = create_transform(**config)\n", | |
| "\n", | |
| " def __call__(self, imagepath):\n", | |
| " # Preprocess the input image\n", | |
| " input_image = Image.open(imagepath).convert(\"RGB\") # Convert to RGB if needed\n", | |
| " input_image = self.preprocess(input_image)\n", | |
| "\n", | |
| " # Convert the image to a PyTorch tensor and add a batch dimension\n", | |
| " input_tensor = input_image.unsqueeze(0)\n", | |
| "\n", | |
| " # Perform inference\n", | |
| " with torch.no_grad():\n", | |
| " output = self.model(input_tensor)\n", | |
| "\n", | |
| " # Extract the feature vector\n", | |
| " feature_vector = output.squeeze().numpy()\n", | |
| "\n", | |
| " return normalize(feature_vector.reshape(1, -1), norm=\"l2\").flatten()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "525b6538-4e4c-4bf2-a849-d1433c0cf36a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "extractor = FeatureExtractor(\"resnet34\")\n", | |
| "\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "# Constants - should be environment variables\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "DIMENSION = 512 \n", | |
| "MILVUS_URL = \"http://192.168.1.163:19530\" \n", | |
| "COLLECTION_NAME = \"pidetections\"\n", | |
| "BUCKET_NAME = \"images\"\n", | |
| "DOWNLOAD_DIR = \"/Users/timothyspann/Downloads/code/images/\"\n", | |
| "AWS_RESOURCE = \"s3\"\n", | |
| "S3_ENDPOINT_URL = \"http://192.168.1.163:9000\"\n", | |
| "AWS_ACCESS_KEY = \"minioadmin\" \n", | |
| "AWS_SECRET_ACCESS_KEY = \"minioadmin\"\n", | |
| "S3_SIGNATURE_VERSION = \"s3v4\"\n", | |
| "AWS_REGION_NAME = \"us-east-1\"\n", | |
| "S3_ERROR_MESSAGE = \"Download failed\"\n", | |
| "# -----------------------------------------------------------------------------" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "806e03fd-3ffb-4c51-8179-b79f5e1980bd", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# -----------------------------------------------------------------------------\n", | |
| "# Connect to Milvus\n", | |
| "\n", | |
| "# Local Docker Server\n", | |
| "milvus_client = MilvusClient( uri=MILVUS_URL)\n", | |
| "# -----------------------------------------------------------------------------" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "8dbf975d-c093-4e99-aa38-b320557719b5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import os\n", | |
| "import boto3\n", | |
| "from botocore.client import Config\n", | |
| "\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "# Access Images on S3 Compatible Store - AWS S3 or Minio or ...\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "s3 = boto3.resource(AWS_RESOURCE,\n", | |
| " endpoint_url=S3_ENDPOINT_URL,\n", | |
| " aws_access_key_id=AWS_ACCESS_KEY,\n", | |
| " aws_secret_access_key=AWS_SECRET_ACCESS_KEY,\n", | |
| " config=Config(signature_version=S3_SIGNATURE_VERSION),\n", | |
| " region_name=AWS_REGION_NAME)\n", | |
| "\n", | |
| "bucket = s3.Bucket(BUCKET_NAME)\n", | |
| "\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "# Get last modified image\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "files = bucket.objects.filter()\n", | |
| "files = [obj.key for obj in sorted(files, key=lambda x: x.last_modified, \n", | |
| " reverse=True)]\n", | |
| "\n", | |
| "for imagename in files:\n", | |
| " query_image = imagename\n", | |
| " break\n", | |
| "\n", | |
| "search_image_name = DOWNLOAD_DIR + query_image\n", | |
| "\n", | |
| "try:\n", | |
| " s3.Bucket(BUCKET_NAME).download_file(query_image, search_image_name)\n", | |
| "except botocore.exceptions.ClientError as e:\n", | |
| " print(S3_ERROR_MESSAGE)\n", | |
| "\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "# Search Milvus for that vector and filter by a label\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "results = milvus_client.search(\n", | |
| " COLLECTION_NAME,\n", | |
| " data=[extractor(search_image_name)],\n", | |
| " filter='label in [\"keyboard\"]',\n", | |
| " output_fields=[\"label\", \"confidence\", \"id\", \"s3path\", \"filename\"],\n", | |
| " search_params={\"metric_type\": \"COSINE\"},\n", | |
| " limit=5\n", | |
| ")\n", | |
| "\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "# Iterate through last five results and display metadata and image\n", | |
| "# -----------------------------------------------------------------------------\n", | |
| "for result in results:\n", | |
| " for hit in result[:5]:\n", | |
| " label = hit[\"entity\"][\"label\"]\n", | |
| " confidence = hit[\"entity\"][\"confidence\"]\n", | |
| " filename = hit[\"entity\"][\"filename\"]\n", | |
| " s3path = hit[\"entity\"][\"s3path\"]\n", | |
| " try:\n", | |
| " s3.Bucket(BUCKET_NAME).download_file(filename, DOWNLOAD_DIR + filename)\n", | |
| " except botocore.exceptions.ClientError as e:\n", | |
| " print(S3_ERROR_MESSAGE)\n", | |
| " print(f\"Detection: {label} {confidence:.2f} for {filename} from {s3path}\" )\n", | |
| " img = Image.open(DOWNLOAD_DIR + filename)\n", | |
| " display(img) \n", | |
| " \n", | |
| "# Enhancement: we could also post this to slack or discord" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "9813a545-7ccd-46ad-9e6b-b2f65fd68067", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment