Task 2 - Doc2Vec
#!/usr/bin/env python3
import os, argparse, logging
# get cpu cores
import multiprocessing
# doc2vec
from gensim.models import Doc2Vec
from gensim.models.doc2vec import LabeledSentence
from gensim import utils
# array storage
import numpy
# classification
from sklearn.linear_model import SGDClassifier
# model persistence
from sklearn.externals import joblib
# load model information
import pickle
# custom score estimation
import evaluate

TOTAL_EMOTICON_TYPES = 40

logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s')
def estimate_sent_cnt(file_path) :
    sent_cnt = [0 for i in range(TOTAL_EMOTICON_TYPES+1)]
    with open(file_path, 'r') as infile :
        for line in infile :
            try :
                emot, text = line.strip().split('\t', maxsplit=1)
            except ValueError :
                continue
            # increment the counter
            sent_cnt[int(emot)] += 1
    return sent_cnt
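# Usage sketch (the path below is hypothetical; the file is expected to hold
# '<emot>\t<text>' lines):
#   sent_cnt = estimate_sent_cnt('some_counts_input.txt')
#   sent_cnt[7]   # number of sentences labeled with emoticon 7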
def extract_vec(model, sent_cnt, vec_dim, logger=None) :
    # ignore the first element, since it contains the test data
    total_sent_cnt = sum(sent_cnt[1:])
    vec_array = numpy.zeros((total_sent_cnt, vec_dim))
    vec_label = numpy.zeros(total_sent_cnt)
    curr_sent_idx = 0
    for emot, index_limit in enumerate(sent_cnt) :
        # ignore the first element in the counter list
        if emot == 0 :
            continue
        if logger :
            logger.info('... processing emoticon {:d}'.format(emot))
        for i in range(1, index_limit+1) :
            prefix = 'EMOTICON_{:d}_{:d}'.format(emot, i)
            vec_array[curr_sent_idx] = model.docvecs[prefix]
            vec_label[curr_sent_idx] = emot
            # increment the overall counter
            curr_sent_idx += 1
    return (vec_array, vec_label)
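# Shape sketch (illustrative values, not from the original run): with
# sent_cnt = [0, 2, 1] and vec_dim = 300, extract_vec() returns a (3, 300)
# feature matrix paired with the label vector [1., 1., 2.], which is exactly
# the (X, y) pair that SGDClassifier.fit() consumes below.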
def is_valid_datfile(file_path, exclude_keyword='test', logger=None) :
    if exclude_keyword not in file_path :
        return True
    else :
        if logger :
            logger.warning(file_path + ' contains "' + exclude_keyword + '", IGNORED')
        return False
def get_args() :
    parser = argparse.ArgumentParser(description='Train the classifier using a trained doc2vec model.')
    parser.add_argument('--algo', '-a', dest='algo',
                        default='sgd',
                        help='training algorithm (currently unused)')
    parser.add_argument('--workers', dest='n_workers', type=int,
                        default=multiprocessing.cpu_count(),
                        help='number of worker threads, defaults to all cores')
    parser.add_argument('--test', '-t', dest='test_accuracy',
                        action='store_true',
                        help='test for accuracy after the training')
    parser.add_argument('--outdir', '-o', dest='out_dir',
                        default='/tmp2/b03902036',
                        help='destination directory for the model file')
    parser.add_argument('--verbose', '-v', dest='verbose',
                        action='count', default=0,
                        help='control the verbosity of the output logs')
    parser.add_argument('mod_file', nargs='+',
                        help='model file from doc2vec training')
    return parser.parse_args()
if __name__ == '__main__' :
    # parse the command line arguments
    args = get_args()
    # get the logger object
    logger = logging.getLogger()
    # set the log level
    if args.verbose >= 2 :
        logger.setLevel(logging.DEBUG)
    elif args.verbose >= 1 :
        logger.setLevel(logging.INFO)
    else :
        logger.setLevel(logging.WARNING)

    if len(args.mod_file) > 1 :
        logger.warning('only the first model file is used, the rest are ignored')
    args.mod_file = args.mod_file[0]
    logger.info('loading model from "{:s}"'.format(args.mod_file))
    model = Doc2Vec.load(args.mod_file)

    logger.info('loading relevant data about the model')
    mif_base = os.path.splitext(args.mod_file)[0]
    with open(mif_base + '.mif', 'rb') as in_file :
        dat_file = pickle.load(in_file)
        sent_cnt = pickle.load(in_file)
        dim = pickle.load(in_file)
    logger.info('... model of {:d} features with {:d} emoticons is loaded'.format(dim, len(sent_cnt)-1))

    # ignore files with the 'test' keyword
    dat_file[:] = [file_path for file_path in dat_file
                   if is_valid_datfile(file_path, logger=logger)]
    for i, count in enumerate(sent_cnt) :
        logger.debug(' Emot {:d}, {:d} sentences'.format(i, count))

    logger.info('extracting vectors from the model')
    vector, label = extract_vec(model, sent_cnt, dim, logger=logger)

    if args.verbose >= 2 :
        # only show the intermediate result in DEBUG mode
        sklearn_verbose = 1
    else :
        sklearn_verbose = 0
    classifier = SGDClassifier(loss='log', verbose=sklearn_verbose, n_jobs=args.n_workers)
    logger.info('training started')
    classifier.fit(vector, label)

    # save the model
    new_filename = '{:s}-d{:d}.cls'.format(args.algo, dim)
    new_filepath = os.path.join(args.out_dir, new_filename)
    joblib.dump(classifier, new_filepath)
    logger.info('classifier saved to {:s}'.format(new_filepath))

    # estimate the accuracy using the assigned scoring scheme
    if args.test_accuracy :
        #accuracy = classifier.score(vector, label)
        accuracy = evaluate.eval_accuracy(model, classifier, sent_cnt, logger=logger)
        logger.info('accuracy = {:f}'.format(accuracy))
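For a quick sanity check of the artifacts the script above produces, they can be reloaded as follows. This is a minimal sketch: the paths assume the default output directory and a 300-dimensional DBOW model, so adjust them to the actual run.

from gensim.models import Doc2Vec
from sklearn.externals import joblib

model = Doc2Vec.load('/tmp2/b03902036/dbow-d300-w10-e20.mod')
classifier = joblib.load('/tmp2/b03902036/sgd-d300.cls')
# probability distribution over the 40 emoticon classes for one sentence vector
vec = model.docvecs['EMOTICON_1_1']
print(classifier.predict_proba(vec.reshape(1, -1)))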
#!/usr/bin/env python3
# algorithm: DBOW
# dimension: 300
# window: 10
# epoch: 20
import os, argparse, logging
# get cpu cores
import multiprocessing
# doc2vec
from gensim.models import Doc2Vec
from gensim.models.doc2vec import LabeledSentence
from gensim import utils
# array storage
import numpy
# random shuffling
from random import shuffle
# save model information
import pickle

TOTAL_EMOTICON_TYPES = 40

logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s')
class LabeledLineSentence(object) :
    def __init__(self, file_list, logger=None) :
        self.file_list = file_list
        self.sentences = []
        self.logger = logger

    def __iter__(self) :
        # reset the counters
        # space: 0 to <emot_cnt>, labeled: 1 to <emot_cnt>
        self.sent_cnt = [0 for i in range(TOTAL_EMOTICON_TYPES+1)]
        skipped_cnt = 0
        for file_path in self.file_list :
            with open(file_path, 'r') as infile :
                for line in infile :
                    try :
                        sid, emot, text = line.strip().split('\t', maxsplit=2)
                    except ValueError :
                        # sid is unbound when the unpacking fails, log the line instead
                        if self.logger :
                            self.logger.error('value error occurred at "{:s}"'.format(line.strip()))
                        skipped_cnt += 1
                        continue
                    # increment the counter
                    self.sent_cnt[int(emot)] += 1
                    label = 'EMOTICON_{:s}_{:d}'.format(emot, self.sent_cnt[int(emot)])
                    if self.logger :
                        self.logger.debug('"{:s}" -> {:s}'.format(label, text))
                    yield LabeledSentence(text.split(' '), [label])
        if self.logger :
            self.logger.warning('{:d} sentences were skipped due to value errors'.format(skipped_cnt))

    def to_array(self) :
        # wipe out the old data
        self.sentences = []
        if self.logger :
            self.logger.debug('calling internal iterator')
        for labeled_sent in self.__iter__() :
            self.sentences.append(labeled_sent)
        if self.logger :
            self.logger.debug('returning a list of length {:d} as result'.format(len(self.sentences)))
        return self.sentences

    def sentences_perm(self) :
        shuffle(self.sentences)
        return self.sentences

    def get_sent_cnt(self) :
        return self.sent_cnt
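# Usage sketch: the corpus is materialized once with to_array() so that
# sentences_perm() can cheaply re-shuffle it between epochs (the file name
# below is hypothetical):
#   corpus = LabeledLineSentence(['train-nopunc.pro'])
#   model.build_vocab(corpus.to_array())
#   model.train(corpus.sentences_perm())   # one pass over a fresh permutation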
def is_valid_infile(file_path, target_ext='.pro', logger=None) :
    ext = os.path.splitext(file_path)[-1].lower()
    if ext != target_ext :
        if logger :
            logger.warning(file_path + ' is not a ' + target_ext + ' file, IGNORED')
        return False
    else :
        return True
def get_args() :
    parser = argparse.ArgumentParser(description='Train the doc2vec model using pre-processed data.')
    parser.add_argument('--algo', '-a', dest='algo',
                        default='dm',
                        help='training algorithm, DM (default) or DBOW')
    parser.add_argument('--dim', '-d', dest='dim', type=int,
                        default=100,
                        help='dimension of the feature vector')
    parser.add_argument('--window', '-w', dest='window', type=int,
                        default=10,
                        help='max distance between the predicted word and context words within a doc')
    parser.add_argument('--workers', dest='n_workers', type=int,
                        default=multiprocessing.cpu_count(),
                        help='number of worker threads, defaults to all cores')
    parser.add_argument('--epoch', '-e', dest='epochs', type=int,
                        default=20,
                        help='number of training epochs')
    parser.add_argument('--outdir', '-o', dest='out_dir',
                        default='/tmp2/b03902036',
                        help='destination directory for the model file')
    parser.add_argument('--verbose', '-v', dest='verbose',
                        action='count', default=0,
                        help='control the verbosity of the output logs')
    parser.add_argument('in_file', nargs='+',
                        help='file(s) to perform the training on')
    return parser.parse_args()
if __name__ == '__main__' :
    # parse the command line arguments
    args = get_args()
    # get the logger object
    logger = logging.getLogger()
    # set the log level
    if args.verbose >= 2 :
        logger.setLevel(logging.DEBUG)
    elif args.verbose >= 1 :
        logger.setLevel(logging.INFO)
    else :
        logger.setLevel(logging.WARNING)

    # verify all the files are valid
    args.in_file[:] = [file_path for file_path in args.in_file
                       if is_valid_infile(file_path, logger=logger)]
    logger.info('loading data from {:d} files'.format(len(args.in_file)))
    for i, file_path in enumerate(args.in_file) :
        logger.info(' File {:d}, {:s}'.format(i+1, file_path))
    sentences = LabeledLineSentence(args.in_file, logger=logger)

    if args.algo == 'dm' :
        dm = 1
    elif args.algo == 'dbow' :
        dm = 0
    else :
        logger.error('invalid training algorithm "{:s}"'.format(args.algo))
        # bail out, otherwise dm would be undefined below
        raise SystemExit(1)

    # build the vocabulary table
    model = Doc2Vec(min_count=1,
                    dm=dm,
                    size=args.dim,
                    window=args.window,
                    sample=1e-5,
                    workers=args.n_workers)
    logger.info('start building the vocabulary')
    model.build_vocab(sentences.to_array())

    logger.info('begin doc2vec training')
    for epoch in range(args.epochs) :
        logger.info('... epoch {:d}, alpha = {:s}'.format(epoch, str(model.alpha)))
        model.train(sentences.sentences_perm())

    # save the model
    new_filename = '{:s}-d{:d}-w{:d}-e{:d}'.format(args.algo, args.dim, args.window, args.epochs)
    new_filepath = os.path.join(args.out_dir, new_filename)
    model.save(new_filepath + '.mod')
    logger.info('model saved to {:s}.mod (.mif contains the info)'.format(new_filepath))

    # save relevant information
    with open(new_filepath + '.mif', 'wb') as out_file :
        # the .pro files used to train the doc2vec model
        pickle.dump(args.in_file, out_file)
        # the per-emoticon sentence counts
        pickle.dump(sentences.get_sent_cnt(), out_file)
        # the dimension of the model
        pickle.dump(args.dim, out_file)
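Note that the training loop above logs model.alpha every epoch but never changes it. The common recipe for this generation of gensim decayed the learning rate manually between epochs; a hedged sketch of that variant (not part of the original script, decay step chosen arbitrarily):

for epoch in range(args.epochs):
    model.train(sentences.sentences_perm())
    model.alpha -= 0.002              # decay the learning rate after each epoch
    model.min_alpha = model.alpha     # keep the rate fixed within an epoch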
#!/usr/bin/env python3
import os, argparse, logging
# Doc2Vec model
from gensim.models import Doc2Vec
# load models stored in pickle format
from sklearn.externals import joblib
# load model information
import pickle

logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s')
def predict(model, classifier, emot, int_id, n_top) :
    # get the feature from the model
    sent_prefix = 'EMOTICON_{:d}_{:d}'.format(emot, int_id)
    feature = model.docvecs[sent_prefix]
    # estimate the probabilities of all classes
    proba = classifier.predict_proba(feature.reshape(1, -1))
    # only one sentence is sent to the predictor
    proba = proba[0]
    # get the top n candidates, candidate ID starts from 1
    candidates = sorted(range(1, len(proba)+1), key=lambda i: proba[i-1])[-n_top:]
    candidates = candidates[::-1]
    return candidates
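# The sort above is equivalent to a numpy argsort over the probabilities,
# shifted by one because class IDs are 1-based (a sketch, assuming numpy is
# imported):
#   candidates = (numpy.argsort(proba)[::-1][:n_top] + 1).tolist()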
def eval_accuracy(model, classifier, sent_cnt, eval_weight=[1, 0.5, 0.333], logger=None) :
    total_score = 0
    total_trial = 0
    for emot_id, counter in enumerate(sent_cnt) :
        # ignore category 0
        if emot_id == 0 :
            continue
        # emot_id  -> emoticon category, the ground truth
        # counter  -> number of sentences labeled with this emoticon
        for int_id in range(1, counter+1) :
            # acquire candidates
            prediction = predict(model,
                                 classifier,
                                 emot=emot_id,
                                 int_id=int_id,
                                 n_top=len(eval_weight))
            # calculate the score for the current sentence
            for j, weight in enumerate(eval_weight) :
                if prediction[j] == emot_id :
                    if logger :
                        logger.debug('%d -> %s', emot_id, str(prediction))
                    total_score += weight
                    break
            total_trial += 1
            # log the running score
            if logger and (total_trial % 10000 == 0) :
                logger.info('%.5f / %d = %.5f', total_score, total_trial,
                            (total_score/total_trial))
    return (total_score/total_trial)
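# Worked example of the weighted scoring above: with eval_weight = [1, 0.5, 0.333],
# a sentence scores 1 if its true emoticon is ranked first, 0.5 if second,
# 0.333 if third, and 0 otherwise; e.g. ground truth 7 against the prediction
# [3, 7, 12] contributes 0.5 to total_score.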
def get_args() :
    parser = argparse.ArgumentParser(description='Evaluation tool for the accuracy of trained models.')
    parser.add_argument('--verbose', '-v', dest='verbose',
                        action='count', default=0,
                        help='control the verbosity of the output logs')
    parser.add_argument('--mod_file', '-m', dest='mod_file', nargs='+',
                        help='doc2vec model')
    parser.add_argument('--cls_file', '-c', dest='cls_file', nargs='+',
                        help='saved classifier')
    return parser.parse_args()
if __name__ == '__main__' :
    # parse the command line arguments
    args = get_args()
    # get the logger object
    logger = logging.getLogger()
    # set the log level
    if args.verbose >= 2 :
        logger.setLevel(logging.DEBUG)
    elif args.verbose >= 1 :
        logger.setLevel(logging.INFO)
    else :
        logger.setLevel(logging.WARNING)

    if len(args.mod_file) > 1 :
        logger.warning('only the first model file is used, the rest are ignored')
    args.mod_file = args.mod_file[0]
    logger.info('loading model from "{:s}"'.format(args.mod_file))
    model = Doc2Vec.load(args.mod_file)

    logger.info('loading relevant data about the model')
    mif_base = os.path.splitext(args.mod_file)[0]
    with open(mif_base + '.mif', 'rb') as in_file :
        # IGNORED
        dat_file = pickle.load(in_file)
        # used by the evaluation
        sent_cnt = pickle.load(in_file)
        # IGNORED
        dim = pickle.load(in_file)
    logger.info('... model of {:d} features with {:d} emoticons is loaded'.format(dim, len(sent_cnt)-1))

    if len(args.cls_file) > 1 :
        logger.warning('only the first classifier is used, the rest are ignored')
    args.cls_file = args.cls_file[0]
    logger.info('loading classifier from "{:s}"'.format(args.cls_file))
    classifier = joblib.load(args.cls_file)

    accuracy = eval_accuracy(model, classifier, sent_cnt, logger=logger)
    logger.info('accuracy = %f', accuracy)
#!/usr/bin/env python3
import os, sys, argparse, logging
# Doc2Vec model
from gensim.models import Doc2Vec
# load models stored in pickle format
from sklearn.externals import joblib
import random

TOTAL_EMOTICON_TYPES = 40

logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s')
def predict(model, classifier, emot, int_id, n_top) :
    # get the feature from the model
    sent_prefix = 'EMOTICON_{:d}_{:d}'.format(emot, int_id)
    feature = model.docvecs[sent_prefix]
    # estimate the probabilities of all classes
    proba = classifier.predict_proba(feature.reshape(1, -1))
    # only one sentence is sent to the predictor
    proba = proba[0]
    # get the top n candidates, candidate ID starts from 1
    candidates = sorted(range(1, len(proba)+1), key=lambda i: proba[i-1])[-n_top:]
    candidates = candidates[::-1]
    return candidates
def batch_predict(model, classifier, ids, n_top=3, logger=None) :
    for sid, emot, int_id in ids :
        if int_id < 0 :
            # the sentence was filtered out, fall back to random guesses
            # (range end is exclusive, so +1 to include the last emoticon)
            prediction = random.sample(range(1, TOTAL_EMOTICON_TYPES+1), n_top)
            random_flag = 'RNG'
        else :
            # acquire candidates
            prediction = predict(model,
                                 classifier,
                                 emot,
                                 int_id,
                                 n_top)
            random_flag = ''
        # log the prediction
        if logger :
            logger.debug('{:s} {:s} {:s}'.format(sid, str(prediction), random_flag))
        yield (sid, prediction)
def load_id_lut(dat_path) :
    id_lut = []
    # counter for the internal ID
    sent_int_id = [0 for i in range(TOTAL_EMOTICON_TYPES+1)]
    with open(dat_path, 'r') as in_file :
        for line in in_file :
            sid, content = line.strip().split('\t', maxsplit=1)
            try :
                emot, txt = content.strip().split('\t', maxsplit=1)
                sent_int_id[int(emot)] += 1
                int_id = sent_int_id[int(emot)]
            except ValueError :
                # the text is empty after the filtering process, so only the
                # emoticon label is left; mark the entry for random guessing
                emot = content.strip()
                int_id = -1
            id_lut.append((sid, int(emot), int_id))
    return id_lut
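# Entry sketch: each tuple is (sentence ID, emoticon, internal ID), e.g.
# ('1234', 17, 42) means test sentence 1234 was the 42nd sentence labeled
# with emoticon 17, so its doc2vec tag is EMOTICON_17_42; int_id == -1 marks
# sentences whose text was emptied by the filtering and will be answered randomly.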
def get_args() :
    parser = argparse.ArgumentParser(description='Generate the answer file for Kaggle submission.')
    parser.add_argument('--verbose', '-v', dest='verbose',
                        action='count', default=0,
                        help='control the verbosity of the output logs')
    parser.add_argument('--mod_file', '-m', dest='mod_file', nargs=1,
                        help='doc2vec model')
    parser.add_argument('--cls_file', '-c', dest='cls_file', nargs=1,
                        help='saved classifier')
    parser.add_argument('--outdir', '-o', dest='out_dir',
                        default='/tmp2/b03902036',
                        help='destination directory for the answer file')
    parser.add_argument('dat_file', nargs='+',
                        help='test data')
    return parser.parse_args()
if __name__ == '__main__' :
    # parse the command line arguments
    args = get_args()
    # get the logger object
    logger = logging.getLogger()
    # set the log level
    if args.verbose >= 2 :
        logger.setLevel(logging.DEBUG)
    elif args.verbose >= 1 :
        logger.setLevel(logging.INFO)
    else :
        logger.setLevel(logging.WARNING)

    if len(args.mod_file) > 1 :
        logger.warning('only the first model file is used, the rest are ignored')
    args.mod_file = args.mod_file[0]
    logger.info('loading model from "{:s}"'.format(args.mod_file))
    model = Doc2Vec.load(args.mod_file)

    if len(args.cls_file) > 1 :
        logger.warning('only the first classifier is used, the rest are ignored')
    args.cls_file = args.cls_file[0]
    logger.info('loading classifier from "{:s}"'.format(args.cls_file))
    classifier = joblib.load(args.cls_file)

    for in_file in args.dat_file :
        logger.info('loading test data from "{:s}"'.format(in_file))
        id_lut = load_id_lut(in_file)
        # create the answer file path
        basename = os.path.basename(in_file)
        new_filename = os.path.splitext(basename)[0] + '.ans'
        new_filepath = os.path.join(args.out_dir, new_filename)
        with open(new_filepath, 'w') as out_file :
            out_file.write('Id,Emoticon\n')
            for sid, prediction in batch_predict(model, classifier, id_lut, logger=logger) :
                # turn the predictions into a single string for printing
                prediction = ' '.join([str(x) for x in prediction])
                out_file.write('{:s},{:s}\n'.format(sid, prediction))
        logger.info('saved to {:s}'.format(new_filepath))
#!/usr/bin/env python3
import os, argparse, logging
# word segmentation
import jieba
# text filtering
import re, string, zhon.hanzi

logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s')
class SentenceLoader(object) :
    def __init__(self,
                 file_path,
                 keep_punc,
                 tokenizer=None, base_marker='EMOTICON',
                 logger=None) :
        self.file_path = file_path
        self.keep_punc = keep_punc
        self.tokenizer = tokenizer
        self.base_marker = base_marker
        self.emot_marker = { i : '{0}_{1}'.format(self.base_marker, i) for i in range(1,41) }
        self.logger = logger

    def __remove_punc(self, s) :
        if not self.keep_punc :
            # re.escape keeps regex metacharacters in the punctuation set inert
            s = re.sub('[%s]' % re.escape(string.punctuation), '', s)
            s = re.sub('[%s]' % zhon.hanzi.punctuation, '', s)
        return s
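    # Effect sketch (hypothetical input): with keep_punc == False,
    #   self.__remove_punc('hello, world!這樣嗎?')  ->  'hello world這樣嗎'
    # i.e. both ASCII and full-width CJK punctuation are stripped.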
    def __iter__(self) :
        with open(self.file_path, 'r') as infile :
            for line in infile :
                sid, emot, text = line.strip().split('\t', maxsplit=2)
                text = self.__remove_punc(text)
                if self.tokenizer :
                    # remove the base marker; str.index raises ValueError
                    # when the line carries no emoticon label
                    try :
                        pos = text.index(self.base_marker)
                    except ValueError :
                        raise ValueError('no emoticon label in ' + str(sid))
                    text = text[:pos] + text[pos+len(self.base_marker):]
                    # cut the sentence using the tokenizer, ignoring blanks
                    text = ' '.join(w for w in self.tokenizer.cut(text) if w != ' ')
                    if self.logger :
                        self.logger.debug(text)
                    yield sid, emot, text
def get_args() :
    parser = argparse.ArgumentParser(description='Pre-process the raw .tsv data for further usage.')
    parser.add_argument('--dict', '-d', dest='dict_file',
                        default='./dict-120k.txt',
                        help='file path to the dictionary')
    parser.add_argument('--keep-punc', '-p', dest='keep_punc',
                        action='store_true',
                        help='keep punctuation')
    parser.add_argument('--outdir', '-o', dest='out_dir',
                        default='/tmp2/b03902036',
                        help='destination directory for the processed file')
    parser.add_argument('--verbose', '-v', dest='verbose',
                        action='count', default=0,
                        help='control the verbosity of the output logs')
    parser.add_argument('in_file', nargs='+',
                        help='file(s) to pre-process')
    return parser.parse_args()
if __name__ == '__main__' :
    # parse the command line arguments
    args = get_args()
    # get the logger object
    logger = logging.getLogger()
    # set the log level
    if args.verbose >= 2 :
        logger.setLevel(logging.DEBUG)
    elif args.verbose >= 1 :
        logger.setLevel(logging.INFO)
    else :
        logger.setLevel(logging.WARNING)

    # manually initialize jieba for customization
    jieba.set_dictionary(args.dict_file)
    jieba.enable_parallel(8)
    tokenizer = jieba.Tokenizer()
    tokenizer.tmp_dir = '.'

    for file_path in args.in_file :
        logger.info('processing "' + file_path + '"...')
        basename = os.path.basename(file_path)
        if args.keep_punc :
            punc_stat = '-punc'
        else :
            punc_stat = '-nopunc'
        new_filename = os.path.splitext(basename)[0] + punc_stat + '.pro'
        sentences = SentenceLoader(file_path,
                                   args.keep_punc,
                                   tokenizer=tokenizer,
                                   logger=logger)
        with open(os.path.join(args.out_dir, new_filename), 'w') as out :
            for sid, emot, text in sentences :
                out.write('{:s}\t{:s}\t{:s}\n'.format(sid, emot, text))
        logger.info('...complete')
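Taken together, the five scripts form one pipeline. The script file names below are hypothetical (the gist does not record them), but the flags and file extensions match the code above:

#   1. pre-process the raw .tsv:   preprocess.py raw.tsv                   -> raw-nopunc.pro
#   2. train the doc2vec model:    train_doc2vec.py -a dbow -d 300 -e 20 raw-nopunc.pro
#                                                                          -> dbow-d300-w10-e20.mod + .mif
#   3. train the SGD classifier:   train_classifier.py -t dbow-d300-w10-e20.mod
#                                                                          -> sgd-d300.cls
#   4. evaluate or predict:        evaluate.py -m <model>.mod -c <cls>.cls, or the
#                                  answer generator, which writes <test>.ans for Kaggle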