Created
March 19, 2015 20:19
-
-
Save stevenRush/1cb7070cadd7c69e9799 to your computer and use it in GitHub Desktop.
bigartm_run.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division
import os
import sys
import time
import math
import codecs
import glob

# Name of the UCI-formatted collection to process and the root folder that
# holds the BigARTM checkout (Windows paths).
collection_name = 'mailru1'
home_folder = 'C://'

# Make the legacy BigARTM Python bindings importable before 'import artm.*'.
sys.path.append(home_folder + 'bigARTM/python')
sys.path.append(home_folder + 'bigARTM/python/artm')

import artm.messages_pb2
import artm.library
#######################################################################################################################
def dump_theta(theta_matrix, file_name):
    """Write a ThetaMatrix message to a plain-text file, one document per line.

    Each line is the item id followed by every topic weight formatted to four
    decimals.  The exact byte layout of the original Python-2 ``print>>``
    version is preserved ("<id> <w> \t <w> \t ... \n", spaces coming from the
    print statement's softspace behavior), but explicit ``write`` calls are
    used so the function also runs under Python 3.

    :param theta_matrix: BigARTM ThetaMatrix-like message exposing ``item_id``
        (sequence of ids) and ``item_weights`` (sequence of objects whose
        ``value`` holds the per-topic weights).
    :param file_name: path of the output text file (overwritten).
    """
    with open(file_name, 'w') as out:
        for index, weights in enumerate(theta_matrix.item_weights):
            out.write(str(theta_matrix.item_id[index]))
            for weight in weights.value:
                # ' <weight> \t' mirrors the softspace-separated print output.
                out.write(' {0:0.4f} \t'.format(weight))
            out.write(' \n')
#######################################################################################################################
# Experiment parameters.
online_timeout = 10            # frequency of model updates in online mode, in milliseconds (passed to WaitIdle)
processors_count = 3           # number of Processor threads to be used in experiment
outer_iterations_count = 3     # number of iterations over whole collection
inner_iterations_count = 10    # number of iterations over each document
kappa = 0.5                    # forgetting-rate exponent between synchronizations (online LDA style)
tau0 = 64                      # forgetting-rate offset between synchronizations
batch_size = 10000             # size of batches in documents
update_every = 3               # how many batches to process before next synchronization
save_and_test_model = False    # save topic model into file and run held-out perplexity estimation
topics_count = 100             # number of topics we need to extract from collection
phi_smooth = 0.01              # ARTM regularizer coefficient for Phi smoothing
theta_smooth = 0.01            # ARTM regularizer coefficient for Theta smoothing
kernel_threshold = 0.25        # tokens with p(t|w) >= this value form the topic kernels
num_top_tokens = 50            # number of top tokens to show per topic

# Paths to the collection, its dictionary and the held-out batches.
collections_path_prefix = r'C:\experiment\UCI_mailru_WH/'
batches_disk_path = collections_path_prefix + collection_name                     # path with batches
dictionary_file = collections_path_prefix + collection_name + '/dictionary'      # path with dictionary
test_batches_folder = collections_path_prefix + collection_name + '/test_batches' # held-out batches for final perplexity estimation
theta_path = collections_path_prefix + 'theta'
results_folder = 'results_new'  # name of results folder
#######################################################################################################################
# Create (if needed) and enter the results folder; all score files below are
# written relative to it.
if (not os.path.isdir(results_folder )):
    os.mkdir(results_folder )
os.chdir(results_folder )

# Open files that accumulate score values over the run; they are closed at the
# end of the script.
perplexity_file = open('perplexity.txt', 'w')
theta_sparsity_file = open('theta_sparsity.txt', 'w')
phi_sparsity_file = open('phi_sparsity.txt', 'w')
topic_kernel_size_file = open('topic_kernel_size.txt', 'w')
topic_kernel_purity_file = open('topic_kernel_purity.txt', 'w')
topic_kernel_contrast_file = open('topic_kernel_contrast.txt', 'w')
top_tokens_file = open('top_tokens.txt', 'w')
time_and_heldout_file = open('time_and_heldout.txt', 'w')
#######################################################################################################################
# Create the configuration of the MasterComponent.
master_config = artm.messages_pb2.MasterComponentConfig()
master_config.disk_path = batches_disk_path        # folder with pre-built .batch files
master_config.processors_count = processors_count

# Read the static dictionary message with information about the collection.
dictionary_message = artm.library.Library().LoadDictionary(dictionary_file)
# Build the MasterComponent and configure scores, regularizers and the model.
with artm.library.MasterComponent(master_config) as master:
    print 'Technical tasks and loading model initialization...'

    # Create static dictionary inside Master.
    dictionary = master.CreateDictionary(dictionary_message)

    # Create and configure scores in Master.
    perplexity_score = master.CreatePerplexityScore()
    sparsity_theta_score = master.CreateSparsityThetaScore()
    sparsity_phi_score = master.CreateSparsityPhiScore()
    topic_kernel_score_config = artm.messages_pb2.TopicKernelScoreConfig()
    topic_kernel_score_config.probability_mass_threshold = kernel_threshold
    topic_kernel_score = master.CreateTopicKernelScore(config = topic_kernel_score_config)
    items_processed_score = master.CreateItemsProcessedScore()
    top_tokens_score = master.CreateTopTokensScore(num_tokens=num_top_tokens)

    # Create smoothing/sparsing regularizers in Master.
    smooth_sparse_phi_reg = master.CreateSmoothSparsePhiRegularizer()
    smooth_sparse_theta_reg = master.CreateSmoothSparseThetaRegularizer()

    # Create configuration of the Model.
    model_config = artm.messages_pb2.ModelConfig()
    model_config.topics_count = topics_count
    model_config.inner_iterations_count = inner_iterations_count

    # Create Model according to its configuration.
    model = master.CreateModel(model_config)

    # Enable all scores in the Model.
    model.EnableScore(perplexity_score)
    model.EnableScore(sparsity_theta_score)
    model.EnableScore(sparsity_phi_score)
    model.EnableScore(topic_kernel_score)
    model.EnableScore(items_processed_score)
    model.EnableScore(top_tokens_score)

    # Enable regularizers in the Model with their tau coefficients.
    model.EnableRegularizer(smooth_sparse_phi_reg, phi_smooth)
    model.EnableRegularizer(smooth_sparse_theta_reg, theta_smooth)

    # Set initial random approximation for the Phi matrix.
    model.Initialize(dictionary)
    ###################################################################################################################
    # Global wall-clock counter over all synchronizations.
    elapsed_time = 0.0
    # Largest number of documents observed in one pass over the collection.
    max_items = 0
    first_sync = True

    # Start collection processing.
    print '\n=======Experiment was started=======\n'
    for outer_iteration in range(0, outer_iterations_count):
        if outer_iteration == outer_iterations_count - 1:
            # Cache Theta on the last pass so it can be retrieved and dumped.
            master.config().cache_theta = True
            master.Reconfigure()

        #master.InvokeIteration()
        # master.WaitIdle()    # wait for all batches are processed
        #model.Synchronize()   # synchronize model

        batches = glob.glob(batches_disk_path + "/*.batch")
        for batch_index, batch_filename in enumerate(batches):
            print 'Processing batch {0}'.format(os.path.basename(batch_filename))
            master.AddBatch(batch_filename=batch_filename)

            # The following rule defines when to retrieve Theta matrix. You decide :)
            if ((batch_index + 1) % 2 == 0) or ((batch_index + 1) == len(batches)):
                # master.WaitIdle()  # wait for all batches are processed
                # model.Synchronize(decay_weight=..., apply_weight=...)  # uncomment for online algorithm
                theta_args = artm.messages_pb2.GetThetaMatrixArgs()
                theta_args.clean_cache = True
                theta_matrix = master.GetThetaMatrix(model=model, args=theta_args)
                filename = os.path.basename(batch_filename)
                path_to_theta = collections_path_prefix + r'/theta/{0}.theta.txt'.format(filename)
                dump_theta(theta_matrix, path_to_theta)
                # NOTE(review): stops the batch loop right after the first Theta
                # dump -- presumably a debugging shortcut; confirm before reuse.
                break

        # NOTE(review): start_time is assigned but never read below
        # (iteration_time is measured from online_start_time instead).
        start_time = time.clock()
        sync_count = -1

        # Invoke one asynchronous scan of the whole collection.
        master.InvokeIteration(1)

        done = False
        next_items_processed = batch_size * update_every
        while (not done):
            online_start_time = time.clock()
            # Wait 'online_timeout' ms and check if the number of processed items had changed.
            done = master.WaitIdle(online_timeout)
            current_items_processed = items_processed_score.GetValue(model).value
            if done or (current_items_processed >= next_items_processed):  # SYNCHRONIZATION!
                sync_count += 1
                update_coef = current_items_processed / (batch_size * update_every)
                next_items_processed = current_items_processed + (batch_size * update_every)  # set next model update
                # Online forgetting rate rho = (tau0 + t)^(-kappa); first sync
                # applies the update fully (decay_weight=0).
                rho = pow(tau0 + update_coef, -kappa)
                model.Synchronize(decay_weight=(0 if first_sync else (1-rho)), apply_weight=rho)  # synchronize model
                first_sync = False
                # Retrieve current score values from the model.
                sparsity_phi_score_value = sparsity_phi_score.GetValue(model).value
                sparsity_theta_score_value = sparsity_theta_score.GetValue(model).value
                topic_kernel_score_value = topic_kernel_score.GetValue(model)
                items_processed_score_value = items_processed_score.GetValue(model).value
                perplexity_score_value = perplexity_score.GetValue(model = model).value

                # Increase the global time counter and save this iteration's time.
                iteration_time = time.clock() - online_start_time
                elapsed_time += iteration_time

                # Display information.
                # NOTE(review): 'output.txt' is opened for append but the print
                # statements below go to stdout, not to this file -- confirm intent.
                with codecs.open('output.txt', 'a', encoding='utf-8') as output:
                    print '=========================================================='
                    print 'Synchronization #' + '%2s' % str(sync_count) +\
                          ' | perplexity = ' + '%6s' %\
                          (str(round(perplexity_score_value)) if (perplexity_score_value != -1) else 'NO')
                    print '----------------------------------------------------------'
                    print 'Phi sparsity = ' + '%7s' % str(round(sparsity_phi_score_value, 4) * 100) +\
                          ' % | ' + 'Theta sparsity = ' + '%7s' %\
                          str(round(sparsity_theta_score_value, 4) * 100) + ' %'
                    print '----------------------------------------------------------'
                    print 'Size = ' + '%7s' % str(round(topic_kernel_score_value.average_kernel_size)) + ' | ' +\
                          'Purity = ' + '%7s' % str(round(topic_kernel_score_value.average_kernel_purity, 3)) + ' | ' +\
                          'Contrast = ' + '%7s' % str(round(topic_kernel_score_value.average_kernel_contrast, 3))
                    print '----------------------------------------------------------'
                    print 'Elapsed time = ' + '%7s' % str(round(iteration_time, 2)) + ' sec.' + ' | ' +\
                          'Items processed = ' + '%10s' % str(items_processed_score_value)
                    print '==========================================================\n\n'

                # Update current max documents count; on later passes offset the
                # counter by documents seen in previous outer iterations.
                if (items_processed_score_value > max_items):
                    max_items = items_processed_score_value
                else:
                    items_processed_score_value += max_items * outer_iteration

                # Append '(items_processed, score)' pairs to the score files.
                perplexity_file.write('(' + str(items_processed_score_value) +\
                                      ', ' + str(round(perplexity_score_value)) + ')\n')
                phi_sparsity_file.write('(' + str(items_processed_score_value) +\
                                        ', ' + str(round(sparsity_phi_score_value, 4) * 100) + ')\n')
                theta_sparsity_file.write('(' + str(items_processed_score_value) +\
                                          ', ' + str(round(sparsity_theta_score_value, 4) * 100) + ')\n')
                topic_kernel_size_file.write('(' + str(items_processed_score_value) +\
                                             ', ' + str(round(topic_kernel_score_value.average_kernel_size)) + ')\n')
                topic_kernel_purity_file.write('(' + str(items_processed_score_value) +\
                                               ', ' + str(round(topic_kernel_score_value.average_kernel_purity, 3)) + ')\n')
                topic_kernel_contrast_file.write('(' + str(items_processed_score_value) +\
                                                 ', ' + str(round(topic_kernel_score_value.average_kernel_contrast, 3)) + ')\n')

    # Total processing time over all outer iterations.
    print 'All elapsed time = ' + "{0:0.2f} sec".format(elapsed_time)
    time_and_heldout_file.write('Elapsed time: ' + str(elapsed_time) + '\n')
    # Put top tokens to file, grouped per topic.
    top_tokens = top_tokens_score.GetValue(model)
    topic_index = -1
    for i in range(0, top_tokens.num_entries):
        if (top_tokens.topic_index[i] != topic_index):
            topic_index = top_tokens.topic_index[i]
            # Blank separator between topics.
            print>>top_tokens_file, "\n\n\n\n\n" #"\n\n\nTopic#" + str(topic_index+1) + ": \n\n",
        print>>top_tokens_file, "%.8f\t" % top_tokens.weight[i], top_tokens.token[i]

    # Optionally serialize the topic model for the held-out test below.
    if (save_and_test_model):
        print 'Saving topic model... ',
        with open(home_folder + 'Output.topic_model', 'wb') as binary_file:
            binary_file.write(master.GetTopicModel(model).SerializeToString())

# Close all opened score files.
# NOTE(review): time_and_heldout_file is intentionally left open here -- it is
# written to (and closed) only in the held-out test section below.
perplexity_file.close()
theta_sparsity_file.close()
phi_sparsity_file.close()
topic_kernel_size_file.close()
topic_kernel_purity_file.close()
topic_kernel_contrast_file.close()
top_tokens_file.close()
############################################################################################################################
# Optional final step: reload the saved model and estimate perplexity on the
# held-out batches.
if (save_and_test_model):
    processors_count = 3  # change number of processors to increase speed if needed

    # Create the configuration of a fresh Master pointed at the held-out batches.
    test_master_config = artm.messages_pb2.MasterComponentConfig()
    test_master_config.processors_count = processors_count
    test_master_config.disk_path = test_batches_folder

    with artm.library.MasterComponent(test_master_config) as test_master:
        # Read saved topic model from file.
        print 'Loading topic model...\n',
        topic_model = artm.messages_pb2.TopicModel()
        with open(home_folder + 'Output.topic_model', 'rb') as binary_file:
            topic_model.ParseFromString(binary_file.read())

        # Create static dictionary in the test Master.
        test_dictionary = test_master.CreateDictionary(dictionary_message)

        # Create perplexity score in the test Master.
        test_perplexity_score = test_master.CreatePerplexityScore()

        # Create the model for testing and enable perplexity scoring in it.
        test_model = test_master.CreateModel(topics_count = topics_count, inner_iterations_count = inner_iterations_count)
        test_model.EnableScore(test_perplexity_score)

        # Restore the previously saved topic model into test_master.
        test_model.Overwrite(topic_model)

        # Process the held-out batches, count perplexity and display the result.
        print 'Estimating perplexity on held-out batches...\n'
        test_master.InvokeIteration()
        test_master.WaitIdle()
        perplexity_score_value = test_perplexity_score.GetValue(test_model).value
        print "Held-out perplexity calculated in BigARTM = %.3f" % perplexity_score_value
        time_and_heldout_file.write('Held-out perplexity: ' + str(round(perplexity_score_value)) + '\n')
        time_and_heldout_file.close()

# NOTE(review): when save_and_test_model is False, time_and_heldout_file is
# never closed explicitly and is released only at interpreter exit.
print '\n=======Experiment was finished=======\n'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment