"""GSDMM (Gibbs Sampling Dirichlet Multinomial Mixture) topic modeling
on a small sample of the 20 Newsgroups training set."""

import pickle
from functools import lru_cache

import nltk
from gsdmm import MovieGroupProcess
from nltk import TweetTokenizer, WordNetLemmatizer
from nltk.corpus import stopwords
from sklearn.datasets import fetch_20newsgroups
from tqdm import tqdm

# The stopword list feeds the token filter below; WordNet backs the lemmatizer.
nltk.download("stopwords")
nltk.download("wordnet")

tt = TweetTokenizer(preserve_case=False)
wnl = WordNetLemmatizer()
stops = set(stopwords.words("english"))


@lru_cache(100000)
def lemmatize(w):
    return wnl.lemmatize(w)


def tokenize(txt):
    # Crude header removal: drop any line that contains a colon
    # ("From:", "Subject:", etc.).
    txt = "\n".join([line.strip() for line in txt.split("\n") if ":" not in line])
    return [lemmatize(t) for t in tt.tokenize(txt) if t.isalpha() and t not in stops]


newsgroups_train = [tokenize(txt) for txt in tqdm(fetch_20newsgroups(subset="train").data[:100])]
print("Texts prepared.")
print("---")

# K is only an upper bound on the number of clusters; GSDMM leaves most of them empty.
mgp = MovieGroupProcess(K=40, alpha=0.1, beta=0.1, n_iters=30)
vocab_size = len({w for txt in newsgroups_train for w in txt})
y = mgp.fit(newsgroups_train, vocab_size)

with open("trained.pickle", "wb") as wf:
    pickle.dump([mgp, y], wf)

for doc, label in zip(newsgroups_train, y):
    print(label, " ".join(doc[:20]))

# One {word: count} dict per cluster.
for word_counts in mgp.cluster_word_distribution:
    print(word_counts)
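
# --- Follow-up sketches (not part of the original gist) ---
#
# Print the top words of each non-empty cluster. This assumes the usual
# gsdmm MovieGroupProcess attributes: cluster_word_distribution is a list
# of {word: count} dicts and cluster_doc_count holds documents per cluster.
TOP_N = 10
for k, word_counts in enumerate(mgp.cluster_word_distribution):
    if mgp.cluster_doc_count[k] == 0:
        continue  # most of the K=40 slots stay empty after sampling
    top = sorted(word_counts.items(), key=lambda wc: wc[1], reverse=True)[:TOP_N]
    print(k, mgp.cluster_doc_count[k], " ".join(w for w, _ in top))

# Label an unseen document, assuming gsdmm's choose_best_label(doc), which
# returns (best_label, probability) for a tokenized document.
label, prob = mgp.choose_best_label(tokenize("NASA plans another shuttle launch"))
print("best label:", label, "p =", round(prob, 3))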