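# Visualize annotated verbatims in the embedding space of a background corpus:
# embed French texts with sentence-camembert-large, fit UMAP on a ~20k sample,
# then project the annotations into the same 2-D space and plot them by label.
# Assumes a Databricks-style session where `spark` and the delta-table paths
# come from the wildcard import below.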
import pandas as pd
from covea.claudia.core.nlp.etudeponctuelle import save_as_table, load_dataframe_from_delta
from netme0a.settings.paths_to_data import *
from transformers import AutoTokenizer
import umap
import seaborn as sns
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
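# French sentence-embedding model; its tokenizer is also used to pre-truncate inputs.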
model_name = "dangvantuan/sentence-camembert-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
def truncate_text(text):
    # Tokenize with truncation at the model's 512-token limit, then decode back
    # to a plain string so long verbatims fit the encoder.
    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)
    return truncated_text
# Load the sentence-embedding model onto GPU 0.
st_model = SentenceTransformer(model_name, trust_remote_code=True).to(0)

def embedder(texts_list, batch_size, device=0):
    # Encode a list of texts into a numpy array of dense vectors.
    return st_model.encode(texts_list, batch_size=batch_size, convert_to_numpy=True, device=device)
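# Load the background corpus: first message and metadata of each discussion thread.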
data = load_dataframe_from_delta(spark, data_first_message_and_metadata_path)
data['annee'] = pd.to_datetime(data['DATE_CREA_FIL_DISC']).dt.year
count_by_year = data.groupby('annee').size().reset_index(name='N')
print(count_by_year)
print(len(data))
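# Drop threads tagged with sub-theme "19_9" and keep only non-trivial texts (> 40 chars).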
data = data[~data['CD_SOUS_THEM_FIL_DISC'].isin(["19_9"])]
print(len(data))
data = data[data['TT_CORP_MSG'].str.len() > 40]
data = data.rename(columns={"NU_FIL_DISC":"id","TT_CORP_MSG":"text"})
print(len(data))
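# Keep recent threads (2022 onward) and subsample 20k rows to fit UMAP on.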
data = data[data['annee'] >= 2022]
data = data.sample(20000).reset_index(drop=True)
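# Load the annotated train split and join the verbatim text back in by thread id.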
id_verbatim = load_dataframe_from_delta(spark,data_first_message_and_metadata_path,['TT_CORP_MSG','NU_FIL_DISC'])
id_verbatim = id_verbatim.rename(columns={"TT_CORP_MSG":"text","NU_FIL_DISC":"id"})
id_verbatim['id'] = id_verbatim['id'].map(str)
annotations = load_dataframe_from_delta(spark, data_split_train_test_train_multilabel_intentions_prio_elody_path)
annotations = pd.merge(annotations,id_verbatim[['id','text']],on="id",how="left")
print(annotations[['value']].value_counts())
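# Collapse multiple annotation rows into one row per text with a list of labels.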
annotations = annotations.groupby(['text','id','type'])['value'].apply(list).reset_index(name='label_lib')
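# Truncate both corpora to the 512-token limit, then embed them in batches.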
annotations['text'] = annotations['text'].map(truncate_text)
data['text'] = data['text'].map(truncate_text)
annotations_embeddings = embedder(annotations['text'].tolist(),batch_size=32)
data_embeddings = embedder(data['text'].tolist(),batch_size=32)
# Fit UMAP on the embeddings of the 20k background sample
umap_model = umap.UMAP()
data_umap = umap_model.fit_transform(data_embeddings)
# Project the annotation embeddings into the same 2-D space (no refit)
annotations_umap = umap_model.transform(annotations_embeddings)
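# Attach the 2-D coordinates to the annotation rows for plotting.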
annotations['X'] = annotations_umap[:, 0]
annotations['Y'] = annotations_umap[:, 1]
# Robust axis bounds from extreme quantiles (0.1% / 99.9%), to clip outliers
# if used in place of the hard-coded limits below.
x_bounds = annotations['X'].quantile([0.001, 0.999]).values
y_bounds = annotations['Y'].quantile([0.001, 0.999]).values
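# Keep roughly `downsampling_nb` examples per label so the plot stays legible;
# texts carrying several labels can push a label slightly above the cap.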
downsampling_nb = 40
annotations_ = annotations.explode("label_lib")
downsampling_id_train = annotations_.sample(frac=1).groupby("label_lib")[['id']].head(downsampling_nb).reset_index()['id'].tolist()
annotations_ = annotations_[annotations_['id'].isin(downsampling_id_train)].reset_index(drop=True)
# Create a seaborn scatter plot
plt.figure(figsize=(10, 8))
scatter_plot = sns.scatterplot(data=annotations_, x='X', y='Y', hue='label_lib', palette='colorblind', legend='full', hue_order=sorted(annotations_['label_lib'].unique()))
scatter_plot.set_title('Visualization of the annotation distribution')
scatter_plot.set_xlabel('UMAP axis 1')
scatter_plot.set_ylabel('UMAP axis 2')
# scatter_plot.set_xlim(x_bounds[0], x_bounds[1])
# scatter_plot.set_ylim(y_bounds[0], y_bounds[1])
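# Hard-coded limits tuned to this particular run; the quantile bounds above are
# a more robust alternative if the projection changes.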
scatter_plot.set_xlim(4, 18)
scatter_plot.set_ylim(2, 15)
plt.legend(title='label_lib', bbox_to_anchor=(1.05, 1), loc='upper left')
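# savefig raises if the target folder is missing; create it first.
from pathlib import Path
Path('./output').mkdir(parents=True, exist_ok=True)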
plt.savefig('./output/umap.png', bbox_inches='tight')
plt.show()