An approach to evaluating a scikit-learn topic model's coherence with gensim, reusing an existing vocabulary.
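The trick, visible in the script below, is that gensim's `CoherenceModel` does not require an actual gensim model: it only consults `model.get_topics()` for the topic-word distributions and a `Dictionary`-like object for the token/id mappings. So scikit-learn's fitted `components_` and the existing vocabulary are wrapped in two small duck-typed stand-ins.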
from collections import Counter
from typing import List

import numpy as np
from scipy import sparse
import pandas as pd
import spacy
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import LatentDirichletAllocation
from gensim.models.coherencemodel import CoherenceModel

nlp = spacy.load('en_core_web_md', disable=['ner', 'parser'])
corpus = fetch_20newsgroups(remove=('headers', 'footers', 'quotes')).data
def map_filter_words(doc):
    """Filtering and lemmatization."""
    for word in doc:
        if word.is_alpha and not word.is_stop:
            yield word.lemma_
texts = []  # tokenized corpus
tf = Counter()  # global term frequency
for doc in map(nlp, corpus):
    # tokenize each document once and reuse the token list
    tokens = list(map_filter_words(doc))
    tf.update(tokens)
    texts.append(tokens)
vocab = sorted(tf)  # the vocabulary
doc_word = sparse.lil_matrix((len(texts), len(vocab)), dtype=int)
for i, doc in enumerate(texts):
    doc_tf = Counter(doc)  # term frequency per document
    r = pd.Series(doc_tf).reindex(vocab).fillna(0).astype(int)
    doc_word[i] = r.to_numpy()
doc_word = doc_word.tocsr()  # CSR is efficient for fitting
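# Hedged sanity check (not in the original gist): the document-term matrix
# should line up with the tokenized corpus and the vocabulary built above.
assert doc_word.shape == (len(texts), len(vocab))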
lda = LatentDirichletAllocation()  # default: n_components=10 topics
lda.fit(doc_word)
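# Optional sanity check before scoring (not part of the original gist):
# print each topic's ten most probable words using the `vocab` built above.
for k, row in enumerate(lda.components_):
    top_idx = row.argsort()[-10:][::-1]
    print('topic %d:' % k, ' '.join(vocab[j] for j in top_idx))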
### NOTE HERE
class DummyTopicModel:
    """Fake a topic model for gensim."""

    def __init__(self, lam):
        """lam: the variational parameters of the topic-word distributions."""
        # normalize each row so that `get_topics` returns proper
        # probability distributions over the vocabulary
        self.lam = lam / np.sum(lam, axis=1, keepdims=True)

    def get_topics(self):
        return self.lam
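# Hedged note: when a `model` is passed, CoherenceModel obtains the topics by
# calling `model.get_topics()`, expecting an (n_topics, n_terms) array of
# word probabilities, which is exactly what this wrapper provides.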
### NOTE HERE
class DummyDictionary:
    """Fake a gensim Dictionary backed by an existing vocabulary."""

    def __init__(self, vocab: List[str]):
        self.token2id = {w: j for j, w in enumerate(vocab)}
        self.id2token = vocab.copy()

    def __getitem__(self, item):
        return self.id2token[item]

    def __contains__(self, item):
        if isinstance(item, int):
            return 0 <= item < len(self.id2token)
        return False
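# Hedged note: CoherenceModel appears to consult only the `token2id` /
# `id2token` mappings (plus integer indexing) of whatever is passed as
# `dictionary=`, so this minimal stand-in suffices.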
cm = CoherenceModel(
    model=DummyTopicModel(lda.components_),
    texts=texts,
    dictionary=DummyDictionary(vocab),
    coherence='c_npmi',
)
coh = np.asarray(cm.get_coherence_per_topic())
print('average topic coherence:', np.mean(coh))
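# A sketch (not in the original gist): the same stand-ins work for gensim's
# other text-based coherence measures, e.g. 'c_v' or 'c_uci'; only the
# `coherence` argument changes.
cm_cv = CoherenceModel(
    model=DummyTopicModel(lda.components_),
    texts=texts,
    dictionary=DummyDictionary(vocab),
    coherence='c_v',
)
print('average c_v coherence:', cm_cv.get_coherence())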