#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import print_function
import os, sys
from torch import optim
from argparse import ArgumentParser
sys.path.append(".")
import importlib.util
import torch
from nmtlab import MTTrainer
from nmtlab.utils import OPTS
from nmtlab.utils import is_root_node
from lib_treeautoencoder import TreeAutoEncoder
from lib_treelstm_dataloader import BilingualTreeDataLoader
from datasets import get_dataset_paths
DATA_ROOT = "./mydata"
ap = ArgumentParser()
ap.add_argument("--resume", action="store_true")
ap.add_argument("--test", action="store_true")
ap.add_argument("--test_nbest", action="store_true")
ap.add_argument("--train", action="store_true")
ap.add_argument("--evaluate", action="store_true")
ap.add_argument("--export_code", action="store_true")
ap.add_argument("--make_target", action="store_true")
ap.add_argument("--make_oracle_codes", action="store_true")
ap.add_argument("--all", action="store_true")
ap.add_argument("--opt_dtok", default="aspec", type=str)
ap.add_argument("--opt_seed", type=int, default=3)
ap.add_argument("--opt_hiddensz", type=int, default=256)
ap.add_argument("--opt_without_source", action="store_true")
ap.add_argument("--opt_codebits", type=int, default=0)
ap.add_argument("--opt_limit_tree_depth", type=int, default=0)
ap.add_argument("--opt_limit_datapoints", type=int, default=-1)
ap.add_argument("--quora_split", default="train")
ap.add_argument("--opt_load_pretrain", action="store_true")
ap.add_argument("--model_path",
default="{}/tree2code.pt".format(DATA_ROOT))
ap.add_argument("--result_path",
default="{}/tree2code.result".format(DATA_ROOT))
OPTS.parse(ap)
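# Note: judging from how the flags are read below, OPTS.parse appears to expose each
# "--opt_*" flag as an attribute without the "opt_" prefix (e.g. --opt_dtok -> OPTS.dtok,
# --opt_hiddensz -> OPTS.hiddensz), while plain flags keep their names (e.g. OPTS.quora_split).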
n_valid_per_epoch = 4
# Define datasets
DATA_ROOT = "./mydata"
dataset_paths = get_dataset_paths(DATA_ROOT, OPTS.dtok)
# Using horovod for training, automatically occupy all GPUs
# Determine the local rank
horovod_installed = importlib.util.find_spec("horovod") is not None
part_index = 0
part_num = 1
gpu_num = 1
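# The comment above mentions horovod, but this trimmed-down script keeps single-process
# defaults. A minimal sketch of how the partition variables could be filled in when
# horovod is installed (assumes horovod.torch's standard hvd.init/rank/size/local_rank
# API; this block is not part of the original gist):
if horovod_installed and torch.cuda.is_available():
    import horovod.torch as hvd
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    part_index = hvd.rank()
    part_num = hvd.size()
    gpu_num = hvd.size()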
print("Running on {} GPUs".format(gpu_num))
# Get codes for quora dataset
QUORA_HOME = "{}/data/quora".format(os.getenv("HOME"))
train_src_corpus = os.path.join(QUORA_HOME, "quora.{}.sp".format(OPTS.quora_split))
train_cfg_corpus = os.path.join(QUORA_HOME, "quora.{}.reference.cfg.txt".format(OPTS.quora_split))
train_tgt_corpus = os.path.join(QUORA_HOME, "quora.{}.reference.sp".format(OPTS.quora_split))
# Define dataset
dataset = BilingualTreeDataLoader(
    src_path=train_src_corpus,
    cfg_path=train_cfg_corpus,
    src_vocab_path=dataset_paths["src_vocab_path"],
    treelstm_vocab_path=dataset_paths["cfg_vocab_path"],
    cache_path=None,
    batch_size=128 * gpu_num,
    part_index=part_index,
    part_num=part_num,
    max_tokens=60,
    limit_datapoints=OPTS.limit_datapoints,
    limit_tree_depth=OPTS.limit_tree_depth
)
# Load the tree autoencoder onto GPU
autoencoder = TreeAutoEncoder(dataset, hidden_size=OPTS.hiddensz, code_bits=OPTS.codebits, without_source=OPTS.without_source)
if torch.cuda.is_available():
    autoencoder.cuda()
assert OPTS.export_code or OPTS.make_target or OPTS.all
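# --export_code: run the trained autoencoder over the Quora corpus and write one line
# per sentence pair, in the form "<src>\t<cfg>\t<code>", to quora_<split>.codes.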
if OPTS.export_code or OPTS.all:
    from nmtlab.utils import Vocab
    import torch
    assert os.path.exists(OPTS.model_path)
    autoencoder.load(OPTS.model_path)
    out_path = "{}/quora_{}.codes".format(DATA_ROOT, OPTS.quora_split)
    if is_root_node():
        autoencoder.train(False)
        if torch.cuda.is_available():
            autoencoder.cuda()
        c = 0
        c1 = 0
        with open(out_path, "w") as outf:
            print("code path", out_path)
            for batch in dataset.yield_all_batches(batch_size=512):
                src_lines, cfg_lines, src_batch, enc_tree, dec_tree = batch
                # Only move the batch to GPU when CUDA is actually available
                if torch.cuda.is_available():
                    src_batch = src_batch.cuda()
                out = autoencoder(src_batch, enc_tree, dec_tree, return_code=True)
                codes = out["codes"]
                for i in range(len(src_lines)):
                    src = src_lines[i]
                    cfg = cfg_lines[i]
                    code = str(codes[i].int().cpu().numpy())
                    outf.write("{}\t{}\t{}\n".format(src, cfg, code))
                outf.flush()
                c += len(src_lines)
                if c - c1 > 10000:
                    sys.stdout.write(".")
                    sys.stdout.flush()
                    c1 = c
            sys.stdout.write("\n")
if OPTS.make_target or OPTS.all:
    if is_root_node():
        export_path = "{}/quora_{}.codes".format(DATA_ROOT, OPTS.quora_split)
        out_path = "{}/quora_{}.tgt".format(DATA_ROOT, OPTS.quora_split)
        print("out path", out_path)
        export_map = {}
        for line in open(export_path):
            if len(line.strip().split("\t")) < 3:
                continue
            src, cfg, code = line.strip().rsplit("\t", maxsplit=2)
            code_str = " ".join(["<c{}>".format(int(c) + 1) for c in code.split()])
            export_map["{}\t{}".format(src, cfg)] = code_str
        with open(out_path, "w") as outf:
            src_path = train_src_corpus
            tgt_path = train_tgt_corpus
            cfg_path = train_cfg_corpus
            for src, tgt, cfg in zip(open(src_path), open(tgt_path), open(cfg_path)):
                key = "{}\t{}".format(src.strip(), cfg.strip())
                if key in export_map:
                    outf.write("{} <eoc> {}\n".format(export_map[key], tgt.strip()))
                else:
                    outf.write("\n")
@zomux (author) commented on Dec 9, 2019:

Commands:

python gen_quora_codes.py --opt_dtok wmt14 --opt_codebits 8 --opt_limit_tree_depth 2 --opt_limit_datapoints 100000 --export_code --make_target --quora_split train
python gen_quora_codes.py --opt_dtok wmt14 --opt_codebits 8 --opt_limit_tree_depth 2 --opt_limit_datapoints 100000 --export_code --make_target --quora_split valid
python gen_quora_codes.py --opt_dtok wmt14 --opt_codebits 8 --opt_limit_tree_depth 2 --opt_limit_datapoints 100000 --export_code --make_target --quora_split test
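
Each run reads quora.<split>.sp, quora.<split>.reference.cfg.txt and quora.<split>.reference.sp from $HOME/data/quora, and writes quora_<split>.codes and quora_<split>.tgt under ./mydata (the DATA_ROOT in the script).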

@shota3506 commented:

Thanks!