{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"%pip install transformers\n",
"%pip install Pillow"
]
},
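{
"cell_type": "markdown",
"metadata": {},
"source": [
"Zero-shot document classification: CLIP scores each image in the `data` folder against three text prompts (driver's license, passport, student ID), and the highest-scoring prompt becomes the predicted class."
]
},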
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jamiedixon/Documents/XXX/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"from transformers import CLIPProcessor, CLIPModel\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image: 8.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.57\n",
"------------------------------\n",
"Image: 9.png\n",
"Predicted class: A student ID\n",
"Confidence: 0.84\n",
"------------------------------\n",
"Image: 11.png\n",
"Predicted class: A student ID\n",
"Confidence: 0.95\n",
"------------------------------\n",
"Image: 10.png\n",
"Predicted class: A student ID\n",
"Confidence: 0.79\n",
"------------------------------\n",
"Image: 4.png\n",
"Predicted class: A passport\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Image: 5.png\n",
"Predicted class: A passport\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Image: 7.png\n",
"Predicted class: A passport\n",
"Confidence: 0.99\n",
"------------------------------\n",
"Image: 6.png\n",
"Predicted class: A passport\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Image: 2.png\n",
"Predicted class: A driver's license\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Image: 3.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.95\n",
"------------------------------\n",
"Image: 1.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.66\n",
"------------------------------\n",
"Image: 0.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.62\n",
"------------------------------\n"
]
}
],
"source": [
"model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
"\n",
"image_folder = \"data\"\n",
"\n",
"descriptions = [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
"\n",
"for filename in os.listdir(image_folder):\n",
" if filename.lower().endswith((\".jpg\", \".jpeg\", \".png\")):\n",
" image_path = os.path.join(image_folder, filename)\n",
" image = Image.open(image_path)\n",
" inputs = processor(text=descriptions, images=image, return_tensors=\"pt\", padding=True)\n",
" outputs = model(**inputs)\n",
" logits_per_image = outputs.logits_per_image \n",
" probs = logits_per_image.softmax(dim=1) \n",
"\n",
" predicted_class = descriptions[probs.argmax()]\n",
" confidence = probs.max().item()\n",
"\n",
" print(f\"Image: {filename}\")\n",
" print(f\"Predicted class: {predicted_class}\")\n",
" print(f\"Confidence: {confidence:.2f}\")\n",
" print(\"-\" * 30)\n"
]
},
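{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch building on the cell above (it reuses `model`, `processor`, `descriptions`, and `image_folder`): print the full per-label probability distribution for a single image and flag low-confidence predictions. The 0.6 threshold and the `0.png` filename are illustrative assumptions, not part of the original pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: full CLIP probability distribution plus a confidence flag.\n",
"# Assumes the previous cell defined model, processor, descriptions, image_folder.\n",
"import torch\n",
"\n",
"THRESHOLD = 0.6  # assumption: tune for your data\n",
"\n",
"image = Image.open(os.path.join(image_folder, \"0.png\"))\n",
"inputs = processor(text=descriptions, images=image, return_tensors=\"pt\", padding=True)\n",
"with torch.no_grad():\n",
"    probs = model(**inputs).logits_per_image.softmax(dim=1).squeeze(0)\n",
"\n",
"for label, p in zip(descriptions, probs.tolist()):\n",
"    print(f\"{label}: {p:.2f}\")\n",
"if probs.max().item() < THRESHOLD:\n",
"    print(\"Low confidence -- consider manual review\")"
]
},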
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from transformers import CLIPModel, CLIPProcessor\n",
"from PIL import Image\n",
"import os\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image: 8.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.57\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 9.png\n",
"Predicted class: A student ID\n",
"Confidence: 0.84\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 11.png\n",
"Predicted class: A student ID\n",
"Confidence: 0.95\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 10.png\n",
"Predicted class: A student ID\n",
"Confidence: 0.79\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 4.png\n",
"Predicted class: A passport\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 5.png\n",
"Predicted class: A passport\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 7.png\n",
"Predicted class: A passport\n",
"Confidence: 0.99\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 6.png\n",
"Predicted class: A passport\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 2.png\n",
"Predicted class: A driver's license\n",
"Confidence: 1.00\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 3.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.95\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 1.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.66\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n",
"Image: 0.png\n",
"Predicted class: A driver's license\n",
"Confidence: 0.62\n",
"------------------------------\n",
"Second Model Prediction: 0\n",
"Second Model Confidence: 0.56\n",
"==============================\n"
]
}
],
"source": [
"#1st model\n",
"clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
"clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
"\n",
"#2nd model\n",
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
"second_model = AutoModelForSequenceClassification.from_pretrained(\"FacebookAI/roberta-base\")\n",
"second_tokenizer = AutoTokenizer.from_pretrained(\"FacebookAI/roberta-base\")\n",
"\n",
"image_folder = \"data\"\n",
"descriptions = [\"A driver's license\", \"A passport\", \"A student ID\"]\n",
"\n",
"for filename in os.listdir(image_folder):\n",
" if filename.lower().endswith((\".jpg\", \".jpeg\", \".png\")):\n",
" # Process the image with CLIP\n",
" image_path = os.path.join(image_folder, filename)\n",
" image = Image.open(image_path)\n",
" inputs = clip_processor(text=descriptions, images=image, return_tensors=\"pt\", padding=True)\n",
" outputs = clip_model(**inputs)\n",
" logits_per_image = outputs.logits_per_image \n",
" probs = logits_per_image.softmax(dim=1) \n",
"\n",
" # Get predicted class and confidence\n",
" predicted_class = descriptions[probs.argmax()]\n",
" confidence = probs.max().item()\n",
"\n",
" print(f\"Image: {filename}\")\n",
" print(f\"Predicted class: {predicted_class}\")\n",
" print(f\"Confidence: {confidence:.2f}\")\n",
" print(\"-\" * 30)\n",
"\n",
" # Prepare input for the second model\n",
" combined_input = f\"Image is predicted as {predicted_class} with confidence {confidence:.2f}.\"\n",
" encoded_input = second_tokenizer(combined_input, return_tensors=\"pt\", padding=True)\n",
"\n",
" # Send data to the second model\n",
" second_model_outputs = second_model(**encoded_input)\n",
" second_model_probs = torch.nn.functional.softmax(second_model_outputs.logits, dim=1)\n",
" second_model_confidence = second_model_probs.max().item()\n",
" second_model_prediction = second_model_probs.argmax().item()\n",
"\n",
" print(f\"Second Model Prediction: {second_model_prediction}\")\n",
" print(f\"Second Model Confidence: {second_model_confidence:.2f}\")\n",
" print(\"=\" * 30)\n"
]
}
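,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The warning above explains the constant output: `FacebookAI/roberta-base` ships without a trained classification head, so `RobertaForSequenceClassification` initializes one randomly and every summary gets the same prediction (0) at ~0.56 confidence. The second model would need fine-tuning on a labeled downstream task before its outputs mean anything. As a sketch of an alternative second stage that needs no fine-tuning, the cell below swaps in a zero-shot text-classification pipeline; the `facebook/bart-large-mnli` checkpoint is an illustrative assumption, not part of the original pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: an NLI-based zero-shot classifier as a trained second stage.\n",
"# Assumption: facebook/bart-large-mnli; reuses descriptions from the cell above.\n",
"from transformers import pipeline\n",
"\n",
"zero_shot = pipeline(\"zero-shot-classification\", model=\"facebook/bart-large-mnli\")\n",
"\n",
"# Same text-summary format the loop above feeds the second model\n",
"combined_input = \"Image is predicted as A passport with confidence 1.00.\"\n",
"result = zero_shot(combined_input, candidate_labels=descriptions)\n",
"print(f\"Second Model Prediction: {result['labels'][0]}\")\n",
"print(f\"Second Model Confidence: {result['scores'][0]:.2f}\")"
]
}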
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}