Created
December 20, 2022 08:28
-
-
Save macleginn/a721361d971c99034188c7668500f8e1 to your computer and use it in GitHub Desktop.
Clusterisation of fine-grained CMP domains based on SBERT sentence similarities
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
from itertools import combinations | |
import pandas as pd | |
import numpy as np | |
from sentence_transformers import SentenceTransformer, util | |
def compute_kernel_bias(vecs, k=None):
    """Compute a whitening kernel and bias for a matrix of row vectors.

    Adapted from: https://github.com/bojone/BERT-whitening

    The covariance of the columns of ``vecs`` is decomposed with SVD and
    the kernel ``W = U * diag(1 / sqrt(s))`` is built so that
    ``(vecs + bias).dot(W)`` has (approximately) identity covariance.

    Directions with numerically zero variance (e.g. a constant feature, or
    fewer samples than dimensions) are mapped to zero instead of being
    scaled by ``1 / sqrt(0)``, which would fill the kernel with inf/NaN.

    Args:
        vecs: 2-D array with one observation per row and one feature per
            column.
        k: Optional number of leading kernel columns to keep. Any falsy
            value (``None`` or ``0``) keeps all columns.

    Returns:
        A ``(kernel, bias)`` tuple, where ``bias`` is the negated column
        mean of ``vecs`` with shape ``(1, n_features)``.
    """
    mu = vecs.mean(axis=0, keepdims=True)
    # np.cov returns a 0-d array for a single feature; svd needs 2-D input.
    cov = np.atleast_2d(np.cov(vecs.T))
    u, s, _ = np.linalg.svd(cov)

    # Same rank tolerance convention as np.linalg.matrix_rank.
    tol = s.max() * max(cov.shape) * np.finfo(s.dtype).eps
    has_variance = s > tol
    inv_sqrt = np.zeros_like(s)
    inv_sqrt[has_variance] = 1.0 / np.sqrt(s[has_variance])

    # Scaling the columns of u is equivalent to np.dot(u, np.diag(inv_sqrt)).
    W = u * inv_sqrt
    if k:
        return W[:, :k], -mu
    return W, -mu
def transform_and_normalize(vecs, kernel=None, bias=None):
    """Optionally apply an affine transform, then L2-normalise each row.

    Adapted from: https://github.com/bojone/BERT-whitening

    The transform ``(vecs + bias).dot(kernel)`` is applied only when both
    ``kernel`` and ``bias`` are supplied; if either is ``None`` the input
    rows are normalised as they are.

    Args:
        vecs: 2-D array of row vectors.
        kernel: Optional matrix that the shifted vectors are multiplied by.
        bias: Optional offset added to ``vecs`` before the multiplication.

    Returns:
        Array in which every row has been divided by its Euclidean norm.
    """
    if kernel is not None and bias is not None:
        shifted = vecs + bias
        vecs = np.dot(shifted, kernel)
    squared_norms = (vecs * vecs).sum(axis=1, keepdims=True)
    row_norms = squared_norms ** 0.5
    return vecs / row_norms
def sbert_representations(sentences, model_name):
    """Encode sentences with a SentenceTransformer model.

    Args:
        sentences: The sentences passed to the model's ``encode`` method.
        model_name: Identifier handed to the ``SentenceTransformer``
            constructor.

    Returns:
        Whatever ``encode`` returns for ``sentences``; a progress bar is
        shown while encoding.
    """
    encoder = SentenceTransformer(model_name)
    embeddings = encoder.encode(sentences, show_progress_bar=True)
    return embeddings
model_name = 'paraphrase-multilingual-mpnet-base-v2'
manifestos_df = pd.read_csv('../data/manifestos.csv', dtype='object')

# Skip the sentences whose domain labels appear fewer than N times in the data.
N = 30
label_counts = manifestos_df.label.value_counts()
frequent_labels = label_counts.index[label_counts >= N]
manifestos_df = manifestos_df.loc[manifestos_df.label.isin(frequent_labels)]
print(f'Using {manifestos_df.shape[0]} sentences with frequent labels')

domains = sorted(manifestos_df.label.unique())

# Positional (not index-label) row numbers per domain: they address rows of
# the embedding matrix, which is built from the filtered frame in order.
domain2indices = defaultdict(list)
for i, d in enumerate(manifestos_df.label):
    domain2indices[d].append(i)

print('Computing embeddings...')
sbert_embeddings = sbert_representations(
    list(manifestos_df.text),
    model_name
)
kernel, bias = compute_kernel_bias(sbert_embeddings, k=None)
normalised_embeddings = transform_and_normalize(sbert_embeddings, kernel, bias)

print('Computing cosine similarities...')
# The embeddings are unit-normalised, so the mean pairwise cosine similarity
# between two domains equals the dot product of their mean vectors. This
# avoids materialising the full sentence-by-sentence similarity matrix.
domain_centroids = {
    d: normalised_embeddings[indices].mean(axis=0)
    for d, indices in domain2indices.items()
}

# NOTE(review): the diagonal keeps the 1.0 fill value, as before; confirm
# that downstream clustering expects this rather than 0.
domain_distance_matrix = pd.DataFrame(1.0, index=domains, columns=domains)
for d1, d2 in combinations(domains, 2):
    distance = 1 - float(np.dot(domain_centroids[d1], domain_centroids[d2]))
    domain_distance_matrix.loc[d1, d2] = distance
    domain_distance_matrix.loc[d2, d1] = distance

domain_distance_matrix.to_csv('domain_distance_matrix.csv')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment