Created
February 20, 2025 14:08
-
-
Save phileas-condemine/82147a9cc9c2c38b87885c88db4e62ac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import umap
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

from covea.claudia.core.nlp.etudeponctuelle import save_as_table, load_dataframe_from_delta
from netme0a.settings.paths_to_data import *
# Hugging Face checkpoint name shared by the tokenizer and the sentence encoder.
model_name = "dangvantuan/sentence-camembert-large"

# Tokenizer of that checkpoint; truncate_text() uses it to cap text length.
tokenizer = AutoTokenizer.from_pretrained(model_name)
def truncate_text(text):
    """Return ``text`` after a truncating round-trip through ``tokenizer``.

    The text is encoded by the module-level ``tokenizer`` with truncation
    enabled and ``max_length=512``, then the resulting ``input_ids`` are
    decoded back to a string with special tokens skipped.

    Args:
        text: The raw text to shorten.

    Returns:
        The decoded string produced from the (possibly truncated) token ids.
    """
    encoded = tokenizer(text, truncation=True, max_length=512)
    return tokenizer.decode(encoded['input_ids'], skip_special_tokens=True)
# Build the sentence encoder from the same checkpoint as the tokenizer, then
# move it to device index 0 (the same default embedder() encodes on).
st_model = SentenceTransformer(model_name, trust_remote_code=True)
st_model = st_model.to(0)
def embedder(texts_list, batch_size, device=0):
    """Encode ``texts_list`` with the module-level ``st_model``.

    Args:
        texts_list: The texts to embed, forwarded as-is to ``st_model.encode``.
        batch_size: Batch size forwarded to ``st_model.encode``.
        device: Device forwarded to ``st_model.encode``; defaults to ``0``.

    Returns:
        Whatever ``st_model.encode`` returns when called with
        ``convert_to_numpy=True``.
    """
    encode_options = {
        "batch_size": batch_size,
        "convert_to_numpy": True,
        "device": device,
    }
    return st_model.encode(texts_list, **encode_options)
# --- Corpus loading and filtering -------------------------------------------
# NOTE(review): `spark` is not defined in this file; it is presumably injected
# by the notebook/cluster runtime. `data_first_message_and_metadata_path`
# presumably comes from the wildcard settings import -- confirm.
data = load_dataframe_from_delta(spark, data_first_message_and_metadata_path)

# Creation year of each thread, used for the per-year report and the filter below.
data['annee'] = pd.to_datetime(data['DATE_CREA_FIL_DISC']).dt.year
count_by_year = data.groupby('annee').size().reset_index(name='N')
print(count_by_year)
print(len(data))

# Drop the excluded sub-theme code.
data = data[~data['CD_SOUS_THEM_FIL_DISC'].isin(["19_9"])]
print(len(data))

# Keep only messages longer than 40 characters, then use generic column names.
data = data[data['TT_CORP_MSG'].str.len() > 40]
data = data.rename(columns={"NU_FIL_DISC": "id", "TT_CORP_MSG": "text"})
print(len(data))

# Restrict to threads created in 2022 or later.
data = data[data['annee'] >= 2022]

# Sample at most 20000 rows. DataFrame.sample raises ValueError when asked for
# more rows than are available (without replacement), so cap the request at
# the number of rows that survived the filters.
data = data.sample(n=min(20000, len(data))).reset_index(drop=True)
# --- Annotation loading -------------------------------------------------------
# Id -> message text lookup, read from the same table as the corpus but limited
# to the two columns needed here.
id_verbatim = load_dataframe_from_delta(
    spark,
    data_first_message_and_metadata_path,
    ['TT_CORP_MSG', 'NU_FIL_DISC'],
)
id_verbatim = id_verbatim.rename(columns={"TT_CORP_MSG": "text", "NU_FIL_DISC": "id"})
# Ids are converted to strings before being used as the join key.
id_verbatim['id'] = id_verbatim['id'].map(str)

annotations = load_dataframe_from_delta(
    spark,
    data_split_train_test_train_multilabel_intentions_prio_elody_path,
)
# Left join: attach the message text to each annotation row by id.
annotations = annotations.merge(id_verbatim[['id', 'text']], on="id", how="left")
print(annotations[['value']].value_counts())

# One row per (text, id, type), with every annotated value gathered in a list.
grouped_values = annotations.groupby(['text', 'id', 'type'])['value']
annotations = grouped_values.apply(list).reset_index(name='label_lib')
# --- Truncation and embedding -------------------------------------------------
# Shorten every text with truncate_text() before encoding it.
annotation_texts = [truncate_text(raw) for raw in annotations['text']]
annotations['text'] = annotation_texts
corpus_texts = [truncate_text(raw) for raw in data['text']]
data['text'] = corpus_texts

# Encode both collections with the same encoder and batch size.
annotations_embeddings = embedder(annotation_texts, batch_size=32)
data_embeddings = embedder(corpus_texts, batch_size=32)
# --- UMAP projection ----------------------------------------------------------
# Fit UMAP (library defaults) on the sampled corpus embeddings.
umap_model = umap.UMAP()
data_umap = umap_model.fit_transform(data_embeddings)

# Project the annotation embeddings with the model fitted above and store the
# two coordinates next to the annotations.
annotations_umap = umap_model.transform(annotations_embeddings)
annotations['X'], annotations['Y'] = annotations_umap[:, 0], annotations_umap[:, 1]

# Despite the names, these hold the 0.1 % and 99.9 % quantiles of each axis.
x_deciles = annotations['X'].quantile([0.001, 0.999]).to_numpy()
y_deciles = annotations['Y'].quantile([0.001, 0.999]).to_numpy()
# --- Per-label down-sampling --------------------------------------------------
# Number of rows drawn per label.
downsampling_nb = 40

# One row per (annotation, label) pair.
annotations_ = annotations.explode("label_lib")

# Shuffle, then take the first `downsampling_nb` rows of each label and collect
# their ids.
shuffled_rows = annotations_.sample(frac=1)
picked_rows = shuffled_rows.groupby("label_lib")[['id']].head(downsampling_nb)
downsampling_id_train = picked_rows['id'].tolist()

# Rows are kept by id, so every label row of a selected id is retained; a label
# can therefore end up with more than `downsampling_nb` rows.
is_selected = annotations_['id'].isin(downsampling_id_train)
annotations_ = annotations_[is_selected].reset_index(drop=True)
# --- Scatter plot of the projected annotations --------------------------------
plt.figure(figsize=(10, 8))
scatter_plot = sns.scatterplot(
    data=annotations_,
    x='X',
    y='Y',
    hue='label_lib',
    palette='colorblind',
    legend='full',
    # Sorted label order keeps the legend ordering stable.
    hue_order=sorted(annotations_['label_lib'].unique()),
)
scatter_plot.set_title('Visualisation de la distribution des annotations')
scatter_plot.set_xlabel('UMAP axe 1')
scatter_plot.set_ylabel('UMAP axe 2')

# NOTE(review): the axis limits are hard-coded. The UMAP model is created
# without a fixed seed, so the projected coordinates presumably differ between
# runs; x_deciles / y_deciles (0.1 % - 99.9 % quantiles) are available as
# data-driven limits if this window stops matching the data.
scatter_plot.set_xlim(4, 18)
scatter_plot.set_ylim(2, 15)

# Put the legend outside the axes, to the right of the plot.
plt.legend(title='label_lib', bbox_to_anchor=(1.05, 1), loc='upper left')

# savefig does not create missing directories; make sure the target exists.
os.makedirs('./output', exist_ok=True)
plt.savefig('./output/umap.png', bbox_inches='tight')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment