-
-
Save behrica/91b3f958fad80247069ade3b96646dcf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
import pickle
from collections import Counter
from itertools import chain

import numpy as np
from gensim.parsing.preprocessing import strip_tags
from gensim.utils import simple_preprocess
from nltk.probability import FreqDist
from top2vec import Top2Vec
from tqdm import tqdm

def default_tokenizer(doc):
    """Tokenize *doc* the way the Top2Vec default tokenizer does.

    Tags are stripped with gensim's ``strip_tags``. The result is then passed
    through gensim's ``simple_preprocess`` with accent removal
    (``deacc=True``), which removes too long/short words.

    This mirrors Top2Vec's built-in tokenizer. If the model was trained with a
    custom tokenizer, use that one instead when computing the measure, so that
    both sides see the same tokens.
    """
    without_tags = strip_tags(doc)
    return simple_preprocess(without_tags, deacc=True)
def PWI(model, docs, num_topics=20, num_words=20, tokenizer=None, show_progress=True):
    """Compute the PWI score of a Top2Vec model over a document collection.

    For every document ``d``, the model is queried for its ``num_topics``
    reduced topics. For each of the first ``num_words`` words ``w`` of each
    topic that also occurs in ``d``, the following term is accumulated::

        f(w|d) * topic_score * log(f(w|d) * p(w) / (p(w) * p(d)))

    The symbols are:

    - ``f(w|d)`` is the relative frequency of ``w`` inside ``d``.
    - ``p(w)`` is the relative frequency of ``w`` over all of ``docs``.
    - ``p(d)`` is ``1 / len(docs)``.
    - ``topic_score`` is the score returned by ``model.query_topics``.

    NOTE: ``model.hierarchical_topic_reduction(num_topics)`` is called before
    scoring, so the model is modified in place.

    :param model: trained top2vec model; must provide
        ``hierarchical_topic_reduction`` and ``query_topics``
    :param docs: sequence of document strings
    :param num_topics: number of (reduced) topics to use in the computation
    :param num_words: number of top words per topic to use
    :param tokenizer: callable mapping a document string to a list of tokens.
        Defaults to ``default_tokenizer``. Pass the tokenizer the model was
        trained with if it was not the Top2Vec default.
    :param show_progress: wrap the per-document loops in ``tqdm`` progress bars
    :return: PWI value (float)
    :raises ValueError: if ``docs`` is empty (p(d) would be undefined)
    """
    if len(docs) == 0:
        # Fail before touching the model: p(d) = 1 / len(docs) is undefined.
        raise ValueError("PWI needs at least one document")
    if tokenizer is None:
        tokenizer = default_tokenizer
    # tqdm is only referenced when progress bars are actually requested.
    progress = tqdm if show_progress else (lambda iterable: iterable)

    model.hierarchical_topic_reduction(num_topics)

    # Tokenize every document exactly once. Both the corpus-wide and the
    # per-document counts are derived from these token lists.
    tokenized_docs = [tokenizer(doc) for doc in progress(docs)]

    # Corpus-wide counts. chain() also copes with documents that tokenize to [].
    corpus_counts = Counter(chain.from_iterable(tokenized_docs))
    corpus_size = sum(corpus_counts.values())

    # Per-document counts and lengths, indexed like `docs`.
    doc_counts = [Counter(tokens) for tokens in tokenized_docs]
    doc_lengths = [len(tokens) for tokens in tokenized_docs]

    p_d = 1 / len(docs)
    total = 0.0
    for i, doc in enumerate(progress(docs)):
        topic_words, _word_scores, topic_scores, _topic_nums = model.query_topics(
            query=doc, num_topics=num_topics, reduced=True
        )
        counts_in_doc = doc_counts[i]
        doc_len = doc_lengths[i]
        for words, t_score in zip(topic_words, topic_scores):
            for word in words[:num_words]:
                if word not in counts_in_doc:
                    # The topic word does not occur in this document (e.g. when
                    # scoring on a different collection than the training one).
                    continue
                # NOTE(review): the original comment called this P(d|w). It is
                # really the relative frequency of `word` inside this document,
                # i.e. an estimate of P(w|d). The formula is kept unchanged;
                # confirm it against the intended PWI definition.
                p_w_in_d = counts_in_doc[word] / doc_len
                p_w = corpus_counts[word] / corpus_size
                p_d_and_w = p_w_in_d * p_w
                total += p_w_in_d * t_score * math.log(p_d_and_w / (p_w * p_d))
    return total
if __name__ == '__main__':
    # Example run on the 20 Newsgroups corpus, with headers, footers and quoted
    # replies removed.
    from sklearn.datasets import fetch_20newsgroups

    corpus = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes')).data

    # Train the model with the settings shown in the original Top2Vec repository.
    model = Top2Vec(documents=corpus, speed="learn", workers=8)

    # Optional caching of the trained model. Uncomment to save or reload it;
    # `with` makes sure the file is closed. Only unpickle files you created
    # yourself.
    # with open('top2vec-20news.pkl', 'wb') as fh:
    #     pickle.dump(model, fh)
    # with open('top2vec-20news.pkl', 'rb') as fh:
    #     model = pickle.load(fh)

    score = PWI(model=model, docs=corpus, num_topics=20, num_words=20)
    print("PWI:", score)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment