Created
October 10, 2018 06:13
-
-
Save adam-phillipps/60fdbb6d35e3216cc6b40c6723241a45 to your computer and use it in GitHub Desktop.
cleaned up a bit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_iterable_corpus(raw_corpus, preprocess=None):
    """Yield one preprocessed document per article of a raw corpus.

    For each article, the section titles and section texts are interleaved
    (title1, text1, title2, text2, ...) and joined with newlines. The whole
    article therefore becomes a single "sentence", which is passed through
    basic preprocessing and yielded.

    Args:
        raw_corpus: Iterable of article mappings. Each article must provide
            ``'section_titles'`` and ``'section_texts'`` sequences of strings.
        preprocess: Callable applied to each joined article string.
            Defaults to ``preprocess_string``.

    Yields:
        ``preprocess(doc)`` for every article, in input order.
    """
    if preprocess is None:
        preprocess = preprocess_string
    for article in raw_corpus:
        # zip() stops at the shorter of the two sequences, so a title without
        # a matching text (or vice versa) is dropped.
        pairs = zip(article['section_titles'], article['section_texts'])
        doc = '\n'.join(itertools.chain.from_iterable(pairs))
        yield preprocess(doc)
if __name__ == "__main__":
    # NOTE(review): these settings are not used in this script body. They are
    # module globals, so a helper defined elsewhere (e.g.
    # gather_command_line_args) may read them. Confirm before removing.
    DEFAULT_VEC_SIZE = 500
    # RAW_CORPUS_URL = 'text8'
    RAW_CORPUS_URL = 'wiki-english-20171001'
    # RAW_CORPUS_URL = 'wiki-en-20171001.txt'
    RESULTS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "results"))
    # RESULTS_DIR = "s3://recon-artifacts"
    _curr_path = os.path.abspath(os.path.dirname(__file__))

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='%(asctime)s : %(threadName)s : %(levelname)s : %(message)s',
        level=logging.INFO,
    )
    logger.info("running %s", " ".join(sys.argv))
    seterr(all='raise')  # don't ignore numpy errors

    args = gather_command_line_args(argparse.ArgumentParser())

    # Resolve relative paths once, against the script directory, so that the
    # existence checks and the actual reads/writes refer to the same files
    # regardless of the current working directory.
    corpus_path = os.path.join(_curr_path, args.corpus_file)
    model_path = os.path.join(_curr_path, args.model)

    # Build the line-sentence corpus file only when it does not exist yet.
    if not os.path.exists(corpus_path):
        if args.raw_url:
            logger.info("Downloading raw corpus file for: %s", args.raw_url)
            raw_corpus = api.load(args.raw_url)
            logger.info("Download complete")
        elif args.raw_file:
            # NOTE(review): api.load is also used for a local file here. If
            # `api` is gensim.downloader, it resolves gensim-data names rather
            # than file paths; confirm this works for args.raw_file.
            logger.info("Loading corpus from: %s", args.raw_file)
            raw_corpus = api.load(args.raw_file)
            logger.info("Load complete")
        else:
            # Without a source there is nothing to build the corpus from;
            # previously this fell through to a NameError on raw_corpus.
            sys.exit("Corpus file %s not found and no raw corpus URL or file "
                     "was given to build it from." % corpus_path)
        logger.info("Saving corpus to disk")
        save_as_line_sentence(create_iterable_corpus(raw_corpus), corpus_path)

    if os.path.exists(model_path):
        logger.info("Loading existing model from: %s", model_path)
        model = Doc2Vec.load(model_path)
    else:
        logger.info("Creating new Doc2Vec model from corpus_file: %s", corpus_path)
        # gensim's Doc2Vec names the number of training passes `epochs`.
        model = Doc2Vec(corpus_file=corpus_path,
                        workers=args.workers,
                        epochs=args.epoch,
                        vector_size=args.vec_size)
        model.save(model_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment