Created
June 11, 2017 14:11
-
-
Save justindavies/e1261f3584fc97b5598c142f8bb828d5 to your computer and use it in GitHub Desktop.
Create doc2vec model from data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train a gensim Doc2Vec model on article text pulled from MySQL.
#
# Steps: read (tweet_id, text) rows, lower-case the text and pad punctuation
# with spaces so each mark becomes its own token, train a Doc2Vec model
# tagged by tweet_id, and save the result to "doc2vec.model".
#
# NOTE(review): the indentation of this script was reconstructed; confirm
# that both train() calls are meant to run on every epoch.
from gensim.models.doc2vec import LabeledSentence
from os import listdir
from os.path import isfile, join
import gensim
import DocIterator as DocIt
import MySQLdb

# Number of passes of the outer training loop.
EPOCHS = 100
# Lowest learning rate ever handed to train(); keeps alpha strictly positive.
ALPHA_FLOOR = 0.0001
# Punctuation that is split off into separate tokens.
PUNCTUATION = ['.', '"', ',', '(', ')', '!', '?', ';', ':']
# Document tags whose nearest neighbours are printed as a training sanity check.
PROBE_TAGS = ("782943325909291008", "783641803358670848")


def _print_probe_neighbours(model):
    """Print the ten most similar documents for each tag in PROBE_TAGS."""
    for tag in PROBE_TAGS:
        print(model.docvecs.most_similar([tag], topn=10))


docLabels = []
data = []

conn = MySQLdb.connect(host="XXXXX", user="XXXXX", passwd="XXXXX", db="XXXXX", charset="utf8")
try:
    cur = conn.cursor()
    try:
        cur.execute('SELECT tweet_id, text from articles where publish_date > "2016-10-01"')
        for row in cur:
            docLabels.append(str(row[0]))
            docu = row[1].lower()
            for char in PUNCTUATION:
                docu = docu.replace(char, ' ' + char + ' ')
            data.append(docu)
    finally:
        cur.close()
finally:
    # Release the connection even if the query or the preprocessing fails.
    conn.close()

print("Examples: " + str(len(data)))

it = DocIt.DocIterator(data, docLabels)

# Earlier experiment: size=50, workers=3, alpha=0.04, min_alpha=0.005.
model = gensim.models.Doc2Vec(dm=1, dm_concat=1, size=100, window=5, negative=5, hs=0, min_count=2, workers=2)
model.build_vocab(it)

# Spread the decay evenly over the whole run. A fixed step of 0.002 over 100
# epochs subtracts 0.2 in total, which pushes alpha below zero long before
# training ends (no starting alpha is set above, so the library default
# applies -- presumably 0.025; verify against the installed gensim).
alpha_step = (model.alpha - ALPHA_FLOOR) / EPOCHS

for epoch in range(EPOCHS):
    print("Epoch " + str(epoch))
    model.train(it)
    _print_probe_neighbours(model)

    # Decrease the learning rate, never letting it reach zero or go negative.
    model.alpha = max(model.alpha - alpha_step, ALPHA_FLOOR)
    print(model.alpha)
    model.min_alpha = model.alpha  # fix the learning rate, no decay
    model.train(it)
    _print_probe_neighbours(model)

model.save("doc2vec.model")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment