-
-
Save Glavin001/275bc82536d43cb3821b4ff90e487135 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
import functools
from io import BytesIO

import numpy as np
import ollama
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from PIL import Image
from safetensors.torch import load_file
from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModel
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only used on CUDA. A float16 pipeline moved to the CPU is
# flagged by diffusers as unable to run, so the CPU fallback uses float32.
_diffusion_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Stable Diffusion v1.5 loaded through the "lpw_stable_diffusion" community
# pipeline (long prompt weighting), shared by every call to generate_image.
diffusion_model = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=_diffusion_dtype,
).to(device)
def generate_image(prompt, image_filename="output.png"):
    """Render ``prompt`` with the module-level diffusion pipeline and save it.

    Args:
        prompt: Text prompt handed to the pipeline.
        image_filename: Path the generated image is written to.

    Returns:
        The first image produced by the pipeline (the same object that was
        saved to ``image_filename``).
    """
    # Long-prompt-weighting community pipeline, see:
    # https://github.com/huggingface/diffusers/tree/main/examples/community#long-prompt-weighting-stable-diffusion
    result = diffusion_model(
        prompt=prompt,
        width=512,
        height=512,
        max_embeddings_multiples=3,
    )
    picture = result.images[0]
    picture.save(image_filename)
    return picture
@functools.lru_cache(maxsize=None)
def _load_clip():
    """Load the CLIP processor and model once; later calls reuse the pair."""
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    return processor, model


def generate_image_embedding_with_clip(image):
    """Return the CLIP image embedding for ``image``.

    Args:
        image: Image accepted by ``CLIPProcessor`` (the script passes PIL
            images).

    Returns:
        The first row of ``CLIPModel.get_image_features`` for the input,
        i.e. the feature tensor of this single image, on ``device``.
    """
    # The weights are loaded lazily on first use and cached, instead of being
    # re-read from disk on every call.
    clip_processor, clip_model = _load_clip()
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        inputs = clip_processor(images=image, return_tensors="pt").to(device)
        image_features = clip_model.get_image_features(**inputs)
    return image_features[0]
@functools.lru_cache(maxsize=None)
def _load_dino():
    """Load the DINOv2 processor and model once; later calls reuse the pair."""
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    model = AutoModel.from_pretrained("facebook/dinov2-base").to(device)
    return processor, model


def generate_image_embedding_with_dino(image):
    """Return a DINOv2 embedding for ``image``.

    Args:
        image: Image accepted by the DINOv2 image processor (the script
            passes PIL images).

    Returns:
        The model's ``last_hidden_state`` averaged over dimension 1, taken
        for the first (only) item of the batch, on ``device``.
    """
    # The weights are loaded lazily on first use and cached, instead of being
    # re-read from disk on every call.
    dino_processor, dino_model = _load_dino()
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        inputs = dino_processor(images=image, return_tensors="pt").to(device)
        outputs = dino_model(**inputs)
        # Collapse dimension 1 of the hidden states into a single vector.
        image_features = outputs.last_hidden_state.mean(dim=1)
    return image_features[0]
def compare_captions(image, short_caption, long_caption):
    """Compare how well two captions reproduce ``image`` when re-rendered.

    Each caption is turned into an image with ``generate_image`` (saved as
    ``short_caption_image.jpg`` and ``long_caption_image.jpg``). The reference
    image and both generated images are then embedded, first with DINO and
    then with CLIP, and for each embedder one line is printed holding the
    similarity score of the short-caption image followed by that of the
    long-caption image.

    Args:
        image: Reference image the captions describe.
        short_caption: First caption to evaluate.
        long_caption: Second caption to evaluate.
    """
    generated_short = generate_image(short_caption, "short_caption_image.jpg")
    generated_long = generate_image(long_caption, "long_caption_image.jpg")

    # DINO first, then CLIP: this fixes the order of the two printed lines.
    embedders = (
        generate_image_embedding_with_dino,
        generate_image_embedding_with_clip,
    )
    for embed in embedders:
        reference_vec = embed(image)
        short_vec = embed(generated_short)
        long_vec = embed(generated_long)
        short_score = calc_cosine_similarity(reference_vec, short_vec)
        long_score = calc_cosine_similarity(reference_vec, long_vec)
        print(short_score, long_score)
def calc_cosine_similarity(embedding_1, embedding_2):
    """Return the cosine similarity of two vectors rescaled to ``[0, 1]``.

    Args:
        embedding_1: First tensor; compared along dimension 0.
        embedding_2: Second tensor; compared along dimension 0.

    Returns:
        A Python float: ``(cosine + 1) / 2``, so 1.0 means same direction,
        0.5 orthogonal and 0.0 opposite.
    """
    raw = nn.functional.cosine_similarity(embedding_1, embedding_2, dim=0)
    # .item() requires a single-element result, i.e. 1-D inputs.
    return (raw.item() + 1) / 2
if __name__ == "__main__":
    # Demo run: score a terse caption against a detailed one for image.jpg.
    reference_image = Image.open("image.jpg")
    terse_caption = "how to build an industry for dollars"
    detailed_caption = "In the image, there is a small black house with a green roof situated in a grassy area surrounded by trees. The house appears to be under construction or renovation, as there are various tools and materials visible around it, such as a hammer, nails, screws, and wood planks. The presence of these objects indicates that the house is being built or repaired, and the green roof adds a unique and eco-friendly feature to the structure."
    compare_captions(
        reference_image,
        short_caption=terse_caption,
        long_caption=detailed_caption,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment