@JGalego
Last active August 27, 2024 00:38
Exploring the modality gap with Amazon Bedrock
# pylint: disable=redefined-outer-name
"""
Exploring the modality gap with Amazon Bedrock
References:
https://jina.ai/news/the-what-and-why-of-text-image-modality-gap-in-clip-models/
"""
import base64
import json
from io import BytesIO
from itertools import combinations
import boto3
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from datasets import load_dataset
from numpy import dot
from numpy.linalg import norm
from PIL import Image
from tqdm import tqdm
from umap import UMAP
####################
# Helper functions #
####################
def cosine_similarity(a, b):
    """
    Computes cosine similarity for two vectors.
    """
    return dot(a, b)/(norm(a)*norm(b))
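# Illustrative sanity check (not part of the original gist):
#   cosine_similarity([1, 0], [1, 0])  # -> 1.0 (identical direction)
#   cosine_similarity([1, 0], [0, 1])  # -> 0.0 (orthogonal)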
def generate_embeddings(
        text: str = None,
        image: Image.Image = None,
        model_id: str = "amazon.titan-embed-image-v1",
        embed_dim: int = 1024):
    """
    Generates embeddings for text and/or image using Amazon Titan Multimodal Embeddings.
    """
    assert text or image, "Image and/or text required!"

    # Initialize client
    bedrock = boto3.client("bedrock-runtime")

    # Initialize request body
    body = {
        'embeddingConfig': {
            'outputEmbeddingLength': embed_dim
        }
    }

    # Process text
    if text:
        body['inputText'] = text

    # Process image
    if image:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        base64_bytes = base64.b64encode(buffered.getvalue())
        base64_string = base64_bytes.decode('utf-8')
        body['inputImage'] = base64_string

    # Make request
    body = json.dumps(body)
    response = bedrock.invoke_model(
        body=body,
        modelId=model_id,
        accept='application/json',
        contentType='application/json'
    )

    # Process response
    response_body = json.loads(response.get('body').read())
    return response_body['embedding']
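# Example call (illustrative sketch, not from the original gist; assumes working AWS
# credentials with access to amazon.titan-embed-image-v1 and a hypothetical local
# file "example.jpg"):
#   caption_vec = generate_embeddings(text="A dog chasing a ball on the beach")
#   photo_vec = generate_embeddings(image=Image.open("example.jpg"))
#   print(cosine_similarity(caption_vec, photo_vec))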
def plot_projections(text_projs, image_projs, text_labels, image_labels):
    """
    Generates a 2D plot of the embedding projections.
    """
    # Initialize settings and figure
    settings = {
        'text': {
            'color': 'blue',
            'opacity': 0.5,
            'symbol': 'circle',
            'size': 5,
        },
        'image': {
            'color': 'orange',
            'opacity': 0.5,
            'symbol': 'square',
            'size': 7,
        }
    }
    fig = go.Figure()

    # Text
    x, y = zip(*text_projs)
    trace = go.Scatter(
        x=x, y=y,
        mode='markers',
        name="Text",
        marker={
            'color': settings['text']['color'],
            'opacity': settings['text']['opacity'],
            'symbol': settings['text']['symbol'],
            'size': settings['text']['size'],
            'line_width': 0
        },
        hoverinfo='text',
        text=text_labels
    )
    fig.add_trace(trace)

    # Image
    x, y = zip(*image_projs)
    trace = go.Scatter(
        x=x, y=y,
        mode='markers',
        name="Image",
        marker={
            'color': settings['image']['color'],
            'opacity': settings['image']['opacity'],
            'symbol': settings['image']['symbol'],
            'size': settings['image']['size'],
            'line_width': 0
        },
        hoverinfo='text',
        text=image_labels
    )
    fig.add_trace(trace)

    fig.update_layout(
        title={
            'text': "UMAP Projection",
            'x': 0.5,
            'xanchor': 'center'
        },
        legend={
            'x': 0.5,
            'xanchor': "center",
            'yanchor': "bottom",
            'orientation': "h"
        }
    )
    return fig
########
# Main #
########
print("Preparing dataset...")
ds = load_dataset("jxie/flickr8k", split="train")
ds = ds.shuffle(seed=42)
samples = ds.select(range(1000))
print("Processing samples...")
image_embeddings = []
text_embeddings = []
image2text = []
text2text = []
for sample in tqdm(samples, ascii="░▒█"):
    i_embeddings = generate_embeddings(image=sample.pop('image'))
    # Image2Text
    caption_embeddings = []
    for key in sample.keys():
        t_embeddings = generate_embeddings(text=sample[key])
        caption_embeddings.append(t_embeddings)
        image2text.append(cosine_similarity(i_embeddings, t_embeddings))
    # Text2Text
    for t_embeddings_comb in combinations(caption_embeddings, 2):
        text2text.append(cosine_similarity(t_embeddings_comb[0], t_embeddings_comb[1]))
    # Store embeddings
    image_embeddings.append(i_embeddings)
    text_embeddings.extend(caption_embeddings)
###########################
# Image2Text vs Text2Text #
###########################
print("Plotting histogram...")
sns.histplot({
    "Image2Text": image2text,
    "Text2Text": text2text
}, stat="density")
plt.xlabel('Cosine Similarity')
plt.ylabel('Density')
plt.show()
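# Note: with a modality gap, the Image2Text similarities tend to sit below the
# Text2Text ones, i.e. captions of the same photo score closer to each other than
# to the photo's own embedding (see the Jina article referenced above).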
###################
# UMAP Projection #
###################
print("Projecting embeddings...")
umap_f = UMAP(
    random_state=42, n_components=2
).fit(text_embeddings + image_embeddings)
image_projs = umap_f.transform(image_embeddings)
text_projs = umap_f.transform(text_embeddings)
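# Both embedding sets are projected with a single UMAP model fitted on their union,
# so text and image points share the same 2D space and any separation between the
# two clusters reflects the modality gap rather than two unrelated projections.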
print("Plotting projections...")
# Static
plt.scatter(*zip(*text_projs), c='b', marker='x', label='Text Embeddings')
plt.scatter(*zip(*image_projs), c='r', marker='s', label='Image Embeddings')
plt.legend(loc='upper left')
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.show()
# Interactive
print("Writing interactive plot...")
# Hover labels must follow the same order as text_embeddings
# (caption_0..caption_4 per sample, sample by sample)
caption_columns = [samples[f'caption_{i}'] for i in range(5)]
text_labels = [caption for per_image in zip(*caption_columns) for caption in per_image]
image_labels = samples['image']
fig = plot_projections(text_projs, image_projs, text_labels, image_labels)
fig.write_html("umap_projections.html")