Skip to content

Instantly share code, notes, and snippets.

def text_clustering_loop(df_text_embeddings, tokeniser, model):
"""
Upon initialisation, pre-computed embeddings for c.1500 images from the conceptual captions dataset are loaded. These are already labelled as one of 7 domains: property, sports, food, musicians, vehicles, illustration and nature
Users are prompted to type a text input related to one of the domains
The app uses CLIP to generate image embeddings for this user-inputted text
The Pre-computed conceptual captions are concatenated with this new user embedding
The UMAP algorithm is used for dimensionality reduction across the concatenated embedding space
The resulting markers/clusters are displayed as a scatter plot
Each marker in the scatter plot represents a text string where similar texts are grouped together
def image_clustering_loop(df_image_embeddings, processor, model):
"""
Users can input a URL to an image (e.g. their LinkedIn photo or any link that links directly to an image (e.g. jpeg, .png)
The app uses CLIP to generate image embeddings for this user-inputted image
Pre-computed rock archive images are concatenated with this new user embedding
The UMAP algorithm is used for dimensionality reduction across the concatenated embedding space
The resulting markers/clusters are displayed as a scatter plot
Each marker in the scatter plot represents an image where similar images are grouped together
Args:
@kitsamho
kitsamho / text_classification.py
Last active April 10, 2023 08:37
Text classification loop for classifying BBC headlines using CLIP
def text_classification_loop(bbc_headlines, tokeniser, model):
"""
Displays a random bbc headline from a list of scraped headlines, prompts the user to enter some contrasting labels
for the headline, and uses the pre-trained model to predict the probabilities of the provided labels.
Returns the predicted probabilities in a dataframe that get visualised as a bar plot
Args:
bbc_headlines (List[str]): A list of BBC headlines to use for classification.
model: A pre-trained CLIP model.
tokenizer: A tokenizer for the CLIP model.
@kitsamho
kitsamho / download_clip_model.py
Created April 8, 2023 15:07
Downloading the CLIP model, processor and tokeniser
from transformers import CLIPModel, AutoProcessor, AutoTokenizer
import streamlit as st
@st.cache_resource
def download_clip_model(clip_model='openai/clip-vit-base-patch32'):
"""
Load the CLIP model and its associated tokenizer and processor from a given pre-trained model.
Args:
@kitsamho
kitsamho / classify_images.py
Last active April 10, 2023 08:06
Function that classifies images using CLIP
def classify_images(text_inputs: list, images: list, processor, model, tokeniser):
"""
Calculates the similarity between the text inputs and images.
Args:
text_inputs (list): List of text inputs.
images (list): List of images urls
processor (transformers.AutoProcessor): The CLIP processor to use.
model (transformers.CLIPModel): The CLIP model to use.
tokenizer (transformers.AutoTokenizer): The CLIP tokenizer to use.
# "About" panel describing the animated plot.
# st.expander / st.columns are the stable names of the former st.beta_expander /
# st.beta_columns; the beta_-prefixed aliases have been removed from Streamlit,
# so calling them on a current release raises AttributeError.
a = st.expander('About (click to expand)')
a.write("This is an animated plot where each step is a week. If you want to explore the animation using other features \
wait until the initial animation has ended or skip through to the end before changing features otherwise you may \
see some odd behaviour in the animation. Plotly can be a little fickle like that.")
a.write("This is a Plotly chart so you can click on the legend to mask values if needed.")
# set up two columns (relative widths 2 and 3) for the interactive widgets
c1, c2 = st.columns((2, 3))
# "About" panel describing the crossplot analysis and correlation heat map.
# st.expander is the stable name of the former st.beta_expander; the beta_-prefixed
# alias has been removed from Streamlit, so calling it raises AttributeError.
a = st.expander('About (click to expand)')
a.write("The crossplot analysis allows you to plot two features against one another with the marker sizes representing \
what we might consider as dependant features e.g. total deaths, total deaths per million, total vaccinations.")
a.write("DataFrame masking allows you to explore the data by continent or include all countries.")
a.write("Each feature's central tendency is represented by the dashed line on each axis so you can see where \
countries are positioned in terms of the distribution for each feature. These update when you change features.")
a.write('The heat map is a summary of the correlations between all features')
# DataFrame set up
def get_indexes(x):
    """Split a series' index labels around its first non-null value.

    Args:
        x: pandas Series whose index labels are integers (they are passed to
            ``range``). NOTE(review): the range-based construction also assumes
            the labels are contiguous and increasing; confirm against callers.

    Returns:
        tuple[list, list]: ``(index_fill_1, index_interpolate)``.
        ``index_fill_1`` holds the labels from ``x.index[0]`` up to, but not
        including, the first non-null label. No data points exist there yet,
        so these are to be filled with a value other than null.
        ``index_interpolate`` holds the labels from the first non-null label up
        to, but not including, ``x.index[-1]``. Data exists in this range but
        may have gaps to interpolate.

    Raises:
        IndexError: if ``x`` is empty or contains only null values.
    """
    # Locate the first non-null label once; dropna() copies the series, so
    # there is no reason to call it twice.
    first_valid = x.dropna().index[0]
    # Leading labels with no data yet -> to be filled with a non-null value.
    index_fill_1 = list(range(x.index[0], first_valid))
    # Labels from the first data point onwards -> candidates for interpolation.
    # NOTE(review): range() excludes its stop value, so the final label
    # x.index[-1] is NOT included here. Confirm whether callers rely on this
    # or whether it is an off-by-one.
    index_interpolate = list(range(first_valid, x.index[-1]))
    return index_fill_1, index_interpolate
def update_series(x):
def get_data(df,transform_cols):
""" This is the main function that transforms the raw OWID data into something we can use in the app
Args:
Original DataFrame from csv
Returns:
Processed / cleaned DataFrame
"""
# loop through and subset each country to a list
country_dfs = []
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from sklearn.decomposition import PCA
def musixmatch_scatplot(df,df_embed,n_components=2):
"""Comprehensive function that adds components to a scatter plot. Requires musixmatch_PCA function """
df = musixmatch_PCA(df,df_embed,n_components) #gets PCA