@kitsamho
Last active April 10, 2023 09:26
import requests
import streamlit as st
import umap
from PIL import Image

# calculate_image_features, get_user_input_dataframe, concatenate_dataframes,
# get_umap_dataframe and plot_image_clusters are helpers defined elsewhere in the app.


def image_clustering_loop(df_image_embeddings, processor, model):
    """
    Users can input a URL that links directly to an image (e.g. their LinkedIn photo, or any .jpeg or .png file).
    The app uses CLIP to generate an image embedding for this user-supplied image.
    The pre-computed rock archive embeddings are concatenated with this new user embedding.
    The UMAP algorithm is used for dimensionality reduction across the concatenated embedding space.
    The resulting markers/clusters are displayed as a scatter plot.
    Each marker in the scatter plot represents an image, and similar images are grouped together.

    Args:
        df_image_embeddings (pandas.DataFrame): DataFrame containing the image embeddings generated by CLIP.
        processor: CLIP processor.
        model: A pre-trained CLIP model.

    Returns:
        None
    """
    # Prompt the user for a direct image URL (the default is an example LinkedIn profile photo)
    user_image_url = st.text_input(
        "Paste your LinkedIn profile photo here and see what rock star you most likely resemble"
        " (you are the bigger orange marker)",
        "https://media.licdn.com/dms/image/C4E03AQH5HLLr9gqm9Q/profile-displayphoto-shrink_200_200/0/1548859027859?e=1686182400&v=beta&t=3l5MRUItcun3pZePOppk4daPG6J3Hu1S5qNMIC0GlyA",
    )
    # Add some empty lines for visual separation
    st.markdown('#')
    st.markdown('#')
    # Download the user's image and convert it to greyscale ('L' mode)
    user_image = Image.open(requests.get(user_image_url, stream=True).raw).convert('L')
    # Generate the (unnormalised) CLIP embedding for the user's image
    user_image_embedding = calculate_image_features(user_image, processor, model, normalise=False).detach().numpy()[0]
    # Wrap the user's embedding in a dataframe and merge it with the pre-generated embeddings
    df_user_image_embedding = get_user_input_dataframe(user_image_url, user_image_embedding)
    df_merged = concatenate_dataframes(df_image_embeddings, df_user_image_embedding)
    # Reduce the embeddings to 2 dimensions with UMAP, then re-attach the image metadata
    df_plot = get_umap_dataframe(umap.UMAP(), df_merged.image_clip_rep.values).join(df_merged)
    # Give every marker a small default size, then enlarge the user's marker.
    # The label 1250 is hard-coded, so this assumes the user's row carries
    # index label 1250 after concatenation.
    df_plot['size'] = 1.5
    df_plot.loc[1250, 'size'] = 30
    # Plot the markers using altair
    plot = plot_image_clusters(df_plot)
    st.altair_chart(plot)
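
The merge, reduce and highlight steps in the function above rely on helpers that are not shown in this gist, so the following is only a minimal sketch of how they might fit together on toy data. `FirstTwoColumns`, `reduce_to_2d` and `set_marker_sizes` are hypothetical names introduced for illustration; the real `concatenate_dataframes` and `get_umap_dataframe` may behave differently. The sketch assumes the user's row is appended last, which lets the larger marker be assigned by position instead of by a hard-coded index label.

```python
import numpy as np
import pandas as pd


class FirstTwoColumns:
    """Stand-in reducer for this toy run: keeps the first two embedding dimensions.

    In the app this role is played by umap.UMAP(), which also exposes fit_transform.
    """

    def fit_transform(self, X):
        return np.asarray(X)[:, :2]


def reduce_to_2d(reducer, embeddings) -> pd.DataFrame:
    """Stack per-image vectors into an (n, d) matrix and project it to columns x and y."""
    matrix = np.vstack(list(embeddings))
    coords = reducer.fit_transform(matrix)
    return pd.DataFrame(coords, columns=["x", "y"])


def set_marker_sizes(df: pd.DataFrame, default: float = 1.5, highlight: float = 30.0) -> pd.DataFrame:
    """Return a copy with a 'size' column; the last row (assumed to be the user) is enlarged."""
    df = df.copy()
    df["size"] = default
    # Position-based assignment, so it does not depend on the number of archive images
    df.iloc[-1, df.columns.get_loc("size")] = highlight
    return df


# Toy stand-ins for the real data: two archive rows and one user row
df_archive = pd.DataFrame({"image_clip_rep": [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]})
df_user = pd.DataFrame({"image_clip_rep": [[6.0, 7.0, 8.0]]})

# ignore_index=True renumbers the rows 0..n-1, so the later join on index lines up
df_merged = pd.concat([df_archive, df_user], ignore_index=True)
df_plot = reduce_to_2d(FirstTwoColumns(), df_merged["image_clip_rep"].values).join(df_merged)
df_plot = set_marker_sizes(df_plot)
print(df_plot[["x", "y", "size"]])
```

Swapping `FirstTwoColumns()` for `umap.UMAP()` would give a real projection, though UMAP generally needs far more than three rows to produce a meaningful layout.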