Created September 10, 2023 17:21
Save alexpaden/517aaa2195b9bbc15a979d09202bb1c0 to your computer and use it in GitHub Desktop.
Binary Meme Classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import lru_cache
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoProcessor
def download_image(url, timeout=10):
    """Download *url* and return it as a PIL image, or ``None`` on failure.

    The HTTP status code and ``Content-Type`` header are printed for debugging.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        A ``PIL.Image.Image`` when the request succeeds with HTTP 200, the
        response declares an ``image/*`` content type, and Pillow can identify
        the data. Otherwise ``None`` (a message is printed).
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors, malformed URLs and timeouts are reported the same
        # way as a bad HTTP status, so callers only need to check for None.
        print(f"Failed to download image from {url}: {exc}")
        return None

    print(f"Status Code: {response.status_code}")
    print(f"Content-Type: {response.headers.get('Content-Type')}")

    # The header can be absent; default to "" so .startswith() never sees None.
    content_type = response.headers.get('Content-Type') or ''
    if response.status_code != 200 or not content_type.startswith('image/'):
        print(f"Failed to download image from {url}")
        return None

    try:
        return Image.open(BytesIO(response.content))
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        print(f"Failed to decode image from {url}")
        return None
@lru_cache(maxsize=4)
def _load_model(model_name):
    """Load the (processor, model) pair for *model_name* once and cache it."""
    # from_pretrained downloads/reads weights, which is far more expensive than
    # a single inference; caching lets repeated classify_meme calls reuse them.
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)
    return processor, model


def classify_meme(image_url, model_name="Hrishikesh332/autotrain-meme-classification-42897109437"):
    """Classify the image at *image_url* as a meme or not.

    Args:
        image_url: URL of the image to download and classify.
        model_name: Hugging Face model id of an image-classification model.
            The processor and model are loaded on first use and cached.

    Returns:
        ``"Meme"`` or ``"Not Meme"``, or ``"Failed to download image."`` when
        the image could not be fetched.
    """
    image = download_image(image_url)
    if image is None:
        return "Failed to download image."

    # The processor expects 3-channel input; palette/RGBA/greyscale images
    # are normalized to RGB first.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_np = np.array(image)

    processor, model = _load_model(model_name)
    inputs = processor(images=image_np, return_tensors="pt")

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = torch.argmax(logits, dim=1).item()

    # NOTE(review): index 0 is assumed to mean "Meme"; confirm against
    # model.config.id2label for the chosen model.
    return "Meme" if predicted_class_idx == 0 else "Not Meme"
if __name__ == "__main__":
    # Interactive entry point: prompt for a URL, then report the verdict.
    url = input("Enter the image URL: ")
    verdict = classify_meme(url)
    print(f"The image is classified as: {verdict}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.