Skip to content

Instantly share code, notes, and snippets.

@kitsamho
Last active April 10, 2023 08:37
Show Gist options
  • Select an option

  • Save kitsamho/42b803d3d3e292dfe5500b999b9f3939 to your computer and use it in GitHub Desktop.

Select an option

Save kitsamho/42b803d3d3e292dfe5500b999b9f3939 to your computer and use it in GitHub Desktop.
text classification loop for classifying bbc headlines using CLIP
def text_classification_loop(bbc_headlines, tokeniser, model):
    """
    Render one round of the interactive headline-classification page.

    Shows a random BBC headline (kept in ``st.session_state['text_keep']`` so
    it survives Streamlit reruns), asks the user for comma-separated candidate
    labels, scores the headline against those labels with the supplied model
    and plots the result. A button lets the user discard the current headline
    and rerun the script to draw a new one.

    Args:
        bbc_headlines (List[str]): Headlines to pick from.
        tokeniser: A tokenizer for the CLIP model; forwarded to
            ``classify_texts``.
        model: A pre-trained CLIP model; forwarded to ``classify_texts``.

    Returns:
        None. All output is drawn on the Streamlit page.
    """
    # The middle column is only a gap between the headline and the chart.
    c1, _, c3 = st.columns((3, 2, 5))

    # Draw a new headline only when none is stored, so that editing the labels
    # (which reruns the script) does not change the headline being classified.
    if 'text_keep' not in st.session_state:
        st.session_state['text_keep'] = get_random_element(bbc_headlines)
    headline = st.session_state['text_keep']

    c1.subheader('Random BBC headline')
    c1.subheader(f'_"{headline}"_')
    # Presumably used as vertical spacers before the input box.
    st.markdown('#')
    st.markdown('#')

    # Prompt user to enter labels for the selected headline
    text_input_string = st.text_input('Choose some labels for this text - seperate labels with a comma e.g. "business headline, sports headline"', "business headline, sports headline")
    # Trim whitespace around each label and drop empty fragments (e.g. from a
    # trailing comma) so only meaningful labels reach the model and the chart.
    labels = [label.strip() for label in text_input_string.split(",") if label.strip()]

    # Classify the selected headline and generate a plot of predicted probabilities
    c3.subheader('Predicted probabilities')
    if labels:
        probs = classify_texts(labels, headline, model, tokeniser)
        df = results_to_dataframe(probs, labels)
        c3.plotly_chart(plot_results(df, x_label='labels', y_label='probabilities', color_discrete_sequence='lightblue'))
    else:
        # Nothing to score against; skip the model call instead of passing it
        # an empty label list.
        c3.warning('Enter at least one label to classify the headline.')

    # Allow user to select a new headline to classify
    more_headlines = st.empty()
    if more_headlines.button('Get new headline'):
        # Forget the stored headline so the next run picks a fresh one.
        st.session_state.pop('text_keep', None)
        more_headlines.empty()
        # Prefer the stable API; fall back for older Streamlit releases that
        # only provide the experimental name.
        rerun = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment