Created
January 13, 2025 01:34
-
-
Save jamessdixon/328f62a95843b68118d07670d07c7f53 to your computer and use it in GitHub Desktop.
document_classification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Run once per environment; %pip installs into the kernel's own interpreter\n", | |
"%pip install transformers torch Pillow" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/jamiedixon/Documents/XXX/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
" from .autonotebook import tqdm as notebook_tqdm\n" | |
] | |
} | |
], | |
"source": [ | |
"import os\n", | |
"\n", | |
"import torch\n", | |
"from PIL import Image\n", | |
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", | |
"from transformers import CLIPProcessor, CLIPModel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Image: 8.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.57\n", | |
"------------------------------\n", | |
"Image: 9.png\n", | |
"Predicted class: A student ID\n", | |
"Confidence: 0.84\n", | |
"------------------------------\n", | |
"Image: 11.png\n", | |
"Predicted class: A student ID\n", | |
"Confidence: 0.95\n", | |
"------------------------------\n", | |
"Image: 10.png\n", | |
"Predicted class: A student ID\n", | |
"Confidence: 0.79\n", | |
"------------------------------\n", | |
"Image: 4.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Image: 5.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Image: 7.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 0.99\n", | |
"------------------------------\n", | |
"Image: 6.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Image: 2.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Image: 3.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.95\n", | |
"------------------------------\n", | |
"Image: 1.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.66\n", | |
"------------------------------\n", | |
"Image: 0.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.62\n", | |
"------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"# Zero-shot classification of every image in `image_folder` against `descriptions` with CLIP\n", | |
"model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n", | |
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n", | |
"\n", | |
"image_folder = \"data\"\n", | |
"descriptions = [\"A driver's license\", \"A passport\", \"A student ID\"]\n", | |
"image_extensions = (\".jpg\", \".jpeg\", \".png\")\n", | |
"\n", | |
"# os.listdir order is arbitrary and filesystem-dependent; sort for reproducible output\n", | |
"image_files = sorted(f for f in os.listdir(image_folder) if f.lower().endswith(image_extensions))\n", | |
"\n", | |
"# Inference only: disable autograd so no gradient graph is built per image\n", | |
"with torch.no_grad():\n", | |
"    for filename in image_files:\n", | |
"        # `with` closes the file handle once the processor has read the pixels\n", | |
"        with Image.open(os.path.join(image_folder, filename)) as image:\n", | |
"            inputs = processor(text=descriptions, images=image, return_tensors=\"pt\", padding=True)\n", | |
"        # softmax over the descriptions axis gives one probability per description\n", | |
"        probs = model(**inputs).logits_per_image.softmax(dim=1)\n", | |
"\n", | |
"        predicted_class = descriptions[probs.argmax().item()]\n", | |
"        confidence = probs.max().item()\n", | |
"\n", | |
"        print(f\"Image: {filename}\")\n", | |
"        print(f\"Predicted class: {predicted_class}\")\n", | |
"        print(f\"Confidence: {confidence:.2f}\")\n", | |
"        print(\"-\" * 30)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Chaining CLIP into a second model\n", | |
"\n", | |
"The next cell renders each CLIP prediction as a sentence and feeds it to `FacebookAI/roberta-base` loaded with `AutoModelForSequenceClassification`.\n", | |
"\n", | |
"**Caveat:** that checkpoint has no trained classification head (the `classifier.*` weights are newly initialized, as the warning in the next cell's output reports). The second model's prediction and confidence are therefore placeholders until the head is fine-tuned on a downstream task." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n", | |
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Image: 8.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.57\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 9.png\n", | |
"Predicted class: A student ID\n", | |
"Confidence: 0.84\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 11.png\n", | |
"Predicted class: A student ID\n", | |
"Confidence: 0.95\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 10.png\n", | |
"Predicted class: A student ID\n", | |
"Confidence: 0.79\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 4.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 5.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 7.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 0.99\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 6.png\n", | |
"Predicted class: A passport\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 2.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 1.00\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 3.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.95\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 1.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.66\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n", | |
"Image: 0.png\n", | |
"Predicted class: A driver's license\n", | |
"Confidence: 0.62\n", | |
"------------------------------\n", | |
"Second Model Prediction: 0\n", | |
"Second Model Confidence: 0.56\n", | |
"==============================\n" | |
] | |
} | |
], | |
"source": [ | |
"# 1st model: CLIP zero-shot image classifier\n", | |
"clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n", | |
"clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n", | |
"\n", | |
"# 2nd model: roberta-base with a newly initialized classification head (see the warning\n", | |
"# in this cell's output). Until that head is fine-tuned its predictions are not meaningful;\n", | |
"# seeding only makes the random head reproducible across kernel restarts.\n", | |
"torch.manual_seed(0)\n", | |
"second_model = AutoModelForSequenceClassification.from_pretrained(\"FacebookAI/roberta-base\")\n", | |
"second_tokenizer = AutoTokenizer.from_pretrained(\"FacebookAI/roberta-base\")\n", | |
"\n", | |
"image_folder = \"data\"\n", | |
"descriptions = [\"A driver's license\", \"A passport\", \"A student ID\"]\n", | |
"image_extensions = (\".jpg\", \".jpeg\", \".png\")\n", | |
"\n", | |
"# os.listdir order is arbitrary and filesystem-dependent; sort for reproducible output\n", | |
"image_files = sorted(f for f in os.listdir(image_folder) if f.lower().endswith(image_extensions))\n", | |
"\n", | |
"# Inference only: disable autograd for both models\n", | |
"with torch.no_grad():\n", | |
"    for filename in image_files:\n", | |
"        # Process the image with CLIP; `with` closes the file handle afterwards\n", | |
"        with Image.open(os.path.join(image_folder, filename)) as image:\n", | |
"            inputs = clip_processor(text=descriptions, images=image, return_tensors=\"pt\", padding=True)\n", | |
"        probs = clip_model(**inputs).logits_per_image.softmax(dim=1)\n", | |
"\n", | |
"        # Get predicted class and confidence\n", | |
"        predicted_class = descriptions[probs.argmax().item()]\n", | |
"        confidence = probs.max().item()\n", | |
"\n", | |
"        print(f\"Image: {filename}\")\n", | |
"        print(f\"Predicted class: {predicted_class}\")\n", | |
"        print(f\"Confidence: {confidence:.2f}\")\n", | |
"        print(\"-\" * 30)\n", | |
"\n", | |
"        # Describe the CLIP result in text and send it to the second model\n", | |
"        combined_input = f\"Image is predicted as {predicted_class} with confidence {confidence:.2f}.\"\n", | |
"        encoded_input = second_tokenizer(combined_input, return_tensors=\"pt\", padding=True)\n", | |
"        second_model_outputs = second_model(**encoded_input)\n", | |
"        second_model_probs = torch.nn.functional.softmax(second_model_outputs.logits, dim=1)\n", | |
"        second_model_confidence = second_model_probs.max().item()\n", | |
"        second_model_prediction = second_model_probs.argmax().item()\n", | |
"\n", | |
"        print(f\"Second Model Prediction: {second_model_prediction}\")\n", | |
"        print(f\"Second Model Confidence: {second_model_confidence:.2f}\")\n", | |
"        print(\"=\" * 30)\n" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment