Last active
April 10, 2023 09:26
import requests
import streamlit as st
import umap
from PIL import Image


def image_clustering_loop(df_image_embeddings, processor, model):
    """
    Let the user paste a URL that points directly to an image (e.g. a LinkedIn
    photo, or any .jpeg/.png link). The app uses CLIP to generate an embedding
    for that image, concatenates it with the pre-computed rock-archive
    embeddings, reduces the combined embedding space to two dimensions with
    UMAP, and renders the result as a scatter plot in which each marker is an
    image and similar images are grouped together.

    Args:
        df_image_embeddings (pandas.DataFrame): DataFrame containing the image
            embeddings generated by CLIP.
        processor: CLIP processor.
        model: A pre-trained CLIP model.

    Returns:
        None
    """
    # Prompt the user for a direct link to their profile photo
    user_image_url = st.text_input(
        "Paste your LinkedIn profile photo here and see what rock star you most likely resemble"
        " (you are the bigger orange marker)",
        "https://media.licdn.com/dms/image/C4E03AQH5HLLr9gqm9Q/profile-displayphoto-shrink_200_200/0/1548859027859?e=1686182400&v=beta&t=3l5MRUItcun3pZePOppk4daPG6J3Hu1S5qNMIC0GlyA")

    # Add some empty lines for visual separation
    st.markdown('#')
    st.markdown('#')

    # Load the user's image (converted to RGB, which CLIP expects) and generate its embedding
    user_image = Image.open(requests.get(user_image_url, stream=True).raw).convert('RGB')
    user_image_embedding = calculate_image_features(
        user_image, processor, model, normalise=False).detach().numpy()[0]

    # Merge the user's image embedding with the pre-generated embeddings
    df_user_image_embedding = get_user_input_dataframe(user_image_url, user_image_embedding)
    df_merged = concatenate_dataframes(df_image_embeddings, df_user_image_embedding)

    # Reduce the embeddings to two dimensions with UMAP
    df_plot = get_umap_dataframe(umap.UMAP(), df_merged.image_clip_rep.values).join(df_merged)

    # Set marker sizes, enlarging the user's marker (their embedding is the
    # last row after concatenation, so avoid a hard-coded row index)
    df_plot['size'] = 1.5
    df_plot.loc[df_plot.index[-1], 'size'] = 30

    # Plot the markers using Altair
    plot = plot_image_clusters(df_plot)
    st.altair_chart(plot)
    return
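Because the user's embedding is concatenated as the final row, the oversized marker can be selected by position rather than a fixed row number. A minimal sketch of that concatenate-then-size step with plain pandas (toy 4-dimensional vectors stand in for CLIP's 512-dimensional embeddings, and the column names are assumptions following the function above):

```python
import numpy as np
import pandas as pd

# Toy archive embeddings (4-dim stand-ins for CLIP's 512-dim vectors)
df_archive = pd.DataFrame({
    "url": ["rock_a.jpg", "rock_b.jpg", "rock_c.jpg"],
    "image_clip_rep": [np.zeros(4), np.ones(4), np.full(4, 2.0)],
})

# The user's embedding, appended as the final row
df_user = pd.DataFrame({
    "url": ["user.jpg"],
    "image_clip_rep": [np.full(4, 0.5)],
})
df_merged = pd.concat([df_archive, df_user], ignore_index=True)

# Uniform marker size, then enlarge the user's marker (always the last row)
df_merged["size"] = 1.5
df_merged.loc[df_merged.index[-1], "size"] = 30.0
```

Selecting `df_merged.index[-1]` keeps the sizing correct however many archive images precede the user's row.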