@yuchenlin
Created December 20, 2020 23:50
Text Clustering with Sentence BERT
from sentence_transformers import SentenceTransformer  # pip install -U sentence-transformers
from sklearn.cluster import KMeans

INPUT_FILE = "/tmp/test_input.txt"

# Read the input corpus: one sentence per line.
with open(INPUT_FILE, "r") as f:
    lines = f.read().splitlines()
print(len(lines))
corpus = [line.strip().lower() for line in lines]

# Encode every sentence into a dense vector with Sentence-BERT.
embedder = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')
corpus_embeddings = embedder.encode(corpus, show_progress_bar=True, batch_size=8)

### KMeans clustering
num_clusters = 100
clustering_model = KMeans(n_clusters=num_clusters)
clustering_model.fit(corpus_embeddings)
cluster_assignment = clustering_model.labels_

# Group sentences by their assigned cluster label, then sort clusters by size.
clusters = [[] for _ in range(num_clusters)]
for sent_id, cluster_label in enumerate(cluster_assignment):
    clusters[cluster_label].append(corpus[sent_id])
clusters.sort(key=len, reverse=True)

# Output: write the non-empty clusters (largest first) to a report file
# and count how many clusters contain at least two sentences.
cnt_groups = 0
text = ""
for c, cluster in enumerate(clusters):
    if cluster:
        text += "\n" + "-" * 50 + "\n"
        text += "Cluster:%d\n" % c
        text += "\n".join(cluster)
        if len(cluster) >= 2:
            cnt_groups += 1
print(cnt_groups)

with open("/tmp/test_cluster.txt", "w") as f:
    f.write(text)
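
To try the script, the expected input is a plain-text file at /tmp/test_input.txt with one sentence per line. Note that scikit-learn's KMeans requires at least as many input lines as num_clusters, so either supply 100+ lines or lower num_clusters to fit your data. A minimal sketch for creating a toy input file (the sentences below are illustrative and not part of the gist):

# Hypothetical example: build a small test corpus for the script above.
# With only a handful of sentences, reduce num_clusters accordingly
# (e.g. num_clusters = 2), since KMeans needs n_samples >= n_clusters.
sample_sentences = [
    "The cat sat on the mat.",
    "A dog barked at the mailman.",
    "Stocks fell sharply on Monday.",
    "Markets rallied after the announcement.",
]
with open("/tmp/test_input.txt", "w") as f:
    f.write("\n".join(sample_sentences))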