@fsndzomga
Created September 12, 2023 11:39
Training and validation
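from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn import metrics

# Assumes `categories`, `newsgroups_train` and `newsgroups_train_enriched`
# (the enriched training texts) were defined in an earlier step.

# Fit TF-IDF on the enriched training texts, then train Naive Bayes on them.
vectorizer = TfidfVectorizer()
vectors = vectorizer.fit_transform(newsgroups_train_enriched)
clf = MultinomialNB(alpha=.01)
clf.fit(vectors, newsgroups_train.target)

# Load the matching test split; reuse the fitted vocabulary (transform, not fit_transform).
newsgroups_test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'), categories=categories)
vectors_test = vectorizer.transform(newsgroups_test.data)
pred = clf.predict(vectors_test)

# f1_score expects (y_true, y_pred), in that order.
print(metrics.f1_score(newsgroups_test.target, pred, average='macro'))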
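The validation step reports macro-averaged F1: an F1 score is computed for each class and the scores are averaged without weighting by class size, so small newsgroups count as much as large ones. The block below is a minimal pure-Python sketch of that computation, for illustration only; `macro_f1` is a hypothetical helper, not scikit-learn's implementation, and it assumes labels are averaged over every class seen in either the true or the predicted labels.

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label in y_true or y_pred.

    Hypothetical helper sketching macro-averaged F1; not scikit-learn code.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # true class t was missed
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in labels:
        # Per-class F1 = 2TP / (2TP + FP + FN); guard against an empty denominator.
        denom = 2 * tp[c] + fp[c] + fn[c]
        scores.append(2 * tp[c] / denom if denom else 0.0)
    return sum(scores) / len(scores)


if __name__ == "__main__":
    y_true = [0, 0, 1, 1, 2]
    y_pred = [0, 1, 1, 1, 2]
    print(macro_f1(y_true, y_pred))
```

For the toy labels above the per-class F1 scores are 2/3, 4/5 and 1, so the macro average is 37/45. Because per-class F1 is unchanged when false positives and false negatives swap roles, macro F1 gives the same value with the arguments reversed; the (y_true, y_pred) order still matters for other metrics such as precision and recall.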