# This script is limited to a single GPU at the moment due to the LSH attention approximation
import argparse
import glob
import logging
import os
from typing import Dict, List

import torch
from reformer_pytorch import ReformerLM
from tokenizers import ByteLevelBPETokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange
from transformers import AdamW, get_linear_schedule_with_warmup, WEIGHTS_NAME

from bertology.cdr_dataset import CdrJsonLDataset
from bertology.constants import TOKENS_COLNAME
from bertology.run_language_modeling import _sorted_checkpoints, _rotate_checkpoints
from bertology.run_language_modeling_scratch import mask_tokens
from bertology.utils import set_seed

logger = logging.getLogger(__name__)


class LineByLineTextDataset(Dataset):
    def __init__(self, file_path, tokenizer):
        self.tokenizer = tokenizer
        self.data = CdrJsonLDataset.read_data(file_path)

    def __len__(self):
        """Denotes the total number of samples"""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Generates one sample of data"""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self._tokenize_chunk_bpe(self.data.loc[idx, TOKENS_COLNAME])

    def _tokenize_chunk_bpe(self, doc):
        """Flatten and tokenize a document of sentences.

        :param doc: list[list[str]] CDR document that was sentencized and then tokenized by the legacy pipeline
        :return: torch tensor of shape (sequence_length,) for the DataLoader to consume
        """
        # TODO: Instead of truncating away trailing tokens, take the center of the document
        flat_doc = [token for sent in doc for token in sent]
        return torch.tensor(self.tokenizer.encode(' '.join(flat_doc)).ids)
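    # Illustrative example (hypothetical tokens; real ids depend on the trained vocab):
    #   doc = [["chest", "pain"], ["no", "fever"]]
    #   -> "chest pain no fever" -> tensor of byte-level BPE ids, one per subword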
    @staticmethod
    def get_max_sequence_length():
        # TODO: Implement df.doc_length.max().tolist()[0]
        return 2 ** 5


def collate(examples: List[torch.Tensor]):
    return pad_sequence(examples, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))
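# Note: collate closes over the module-level `tokenizer` created in __main__, so this only
# works when the file is run as a script. What pad_sequence does here, shapes only
# (assuming the pad token maps to id 1):
#   [tensor([5, 6]), tensor([7])] -> tensor([[5, 6], [7, 1]])   # (batch, max_len)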


def evaluate(args, eval_dataset, model, tokenizer, loss_fn, prefix="") -> Dict:
    eval_output_dir = args.output_dir
    os.makedirs(eval_output_dir, exist_ok=True)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpus)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate,
    )

    # TODO: multi-gpu evaluate
    if args.n_gpus > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = mask_tokens(batch, tokenizer, args)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            output = model(inputs)

        loss_mx = labels != -100
        output_ids = output[loss_mx].view(-1, tokenizer.get_vocab_size())
        labels = labels[loss_mx].view(-1)
        eval_loss += loss_fn(output_ids, labels).mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
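    # exp of the mean masked-token cross-entropy; as an MLM "perplexity" it is only
    # comparable between runs that use the same --mlm_probability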
    perplexity = torch.exp(torch.tensor(eval_loss))
    result = {"perplexity": perplexity.item(), "loss": eval_loss}

    # `prefix` may name a checkpoint subdirectory that does not exist yet
    os.makedirs(os.path.join(eval_output_dir, prefix), exist_ok=True)
    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


def train(args, train_dataset, model, tokenizer, loss_fn):
    tb_writer = SummaryWriter()
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpus)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
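    # Biases and LayerNorm weights are excluded from weight decay by convention:
    # decaying them adds no useful regularization and can slow optimization.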
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # TODO: multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpus > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps,
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
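            # e.g. --model_name_or_path "out/checkpoint-5000" -> checkpoint_suffix "5000"
            # -> resume at global_step 5000; epoch and in-epoch offset are derived from it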
logger.info(" Continuing training from checkpoint, will skip to saved global_step") | |
logger.info(" Continuing training from epoch %d", epochs_trained) | |
logger.info(" Continuing training from global step %d", global_step) | |
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) | |
except ValueError: | |
logger.info(" Starting fine-tuning.") | |
tr_loss, logging_loss = 0.0, 0.0 | |
model.zero_grad() | |
if args.evaluate_during_training: | |
eval_dataset = LineByLineTextDataset(args.eval_data_file, tokenizer) | |
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch") | |
set_seed(args) # Added here for reproducibility | |
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
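            # mask_tokens (assumed to follow the standard HF MLM scheme) replaces roughly
            # mlm_probability of the tokens in `inputs` and sets `labels` to -100 at every
            # unmasked position, matching the ignore_index convention of CrossEntropyLoss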
            inputs, labels = mask_tokens(batch, tokenizer, args)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            output = model(inputs)

            # only calculating loss on masked tokens
            loss_mx = labels != -100
            output = output[loss_mx].view(-1, tokenizer.get_vocab_size())
            labels = labels[loss_mx].view(-1)
            loss = loss_fn(output, labels)

            if args.n_gpus > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.evaluate_during_training:
                        # Only evaluate on a single GPU; otherwise metrics may not average well
                        results = evaluate(args, eval_dataset, model, tokenizer, loss_fn)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
                    tokenizer.save(output_dir)  # save into the checkpoint dir, not args.output_dir
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saved model, tokenizer and args to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saved optimizer and scheduler states to %s", output_dir)

    tb_writer.close()
    return global_step, tr_loss / global_step


def load_tokenizer(output_dir):
    tokenizer = ByteLevelBPETokenizer(
        f"{output_dir}/vocab.json",
        f"{output_dir}/merges.txt",
    )
    return tokenizer
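
# Note: load_tokenizer assumes vocab.json and merges.txt from a previously trained
# ByteLevelBPETokenizer already exist in the given directory; this script trains the
# language model, not the tokenizer.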


def get_model():
    model = ReformerLM(
        num_tokens=VOCAB_SIZE,
        dim=1024,
        depth=1,
        max_seq_len=MAX_SEQUENCE_LENGTH,
        heads=8,
        lsh_dropout=0.1,
        ff_dropout=0.1,
        post_attn_dropout=0.1,
        layer_dropout=0.1,          # layer dropout from the 'Reducing Transformer Depth on Demand' paper
        causal=False,               # auto-regressive or not
        bucket_size=64,             # average size of qk per bucket; 64 is recommended in the paper
        n_hashes=4,                 # 4 is permissible per the authors; 8 is better but slower
        emb_dim=1024,               # embedding factorization for further memory savings
        ff_chunks=200,              # number of chunks for the feedforward layer; increase if there are memory issues
        attn_chunks=8,              # process LSH attention in chunks; the only way to fit memory when scaling to 16k tokens
        num_mem_kv=128,             # persistent learned memory key/values, from the all-attention paper
        twin_attention=False,       # if True, both branches of the reversible network are attention
        full_attn_thres=1024,       # use full attention if the context length is below this value
        reverse_thres=1024,         # turn off reversibility, for 2x speed, at or below this sequence length
        use_scale_norm=True,        # use scale norm from the 'Transformers without Tears' paper
        one_value_head=True,        # use one set of values for all heads, from 'One Write-Head Is All You Need'
        weight_tie=False,           # tie parameters of each layer so additional depth costs no memory
        weight_tie_embedding=True,  # use the token embedding to project the output; some papers report better results
        use_full_attn=False,        # only turn on to force full attention at all lengths, for comparison with LSH
        axial_position_emb=True,
        axial_position_shape=(8, 4),        # the shape must multiply up to max_seq_len (8 x 4 = 32 here)
        # axial_position_shape=(128, 128),  # e.g. for max_seq_len = 16384 (128 x 128)
        axial_position_dims=(512, 512),     # the dims must sum to the model dimension (512 + 512 = 1024)
    )
    return model
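
# Note: with max_seq_len = 32 and full_attn_thres = 1024, every sequence here actually
# falls back to full attention; the LSH settings (bucket_size, n_hashes, attn_chunks)
# only take effect for contexts longer than full_attn_thres, where reformer_pytorch
# also expects the sequence length to be divisible by bucket_size * 2.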


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
    )
    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from the latest checkpoint in output_dir"
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="The model checkpoint for weights initialization. Leave None to train a model from scratch.",
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
    )
    parser.add_argument(
        "--config_name",
        default=None,
        type=str,
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
    )
    parser.add_argument(
        "--tokenizer_path",
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=8,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X update steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save a checkpoint every X update steps.")
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total number of checkpoints by deleting older checkpoints in output_dir; no deletion by default",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with a step number",
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA even when it is available")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    args = parser.parse_args()

    if args.eval_data_file is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )

    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    args.device = torch.device("cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpus = torch.cuda.device_count()
    logger.warning(
        "Device: %s, n_gpus: %s, 16-bits training: %s",
        args.device,
        args.n_gpus,
        args.fp16,
    )

    # Set seed
    args.seed = 42
    set_seed(args)

    tokenizer = load_tokenizer(args.tokenizer_path)
    train_dataset = LineByLineTextDataset(args.train_data_file, tokenizer)
    MAX_SEQUENCE_LENGTH = train_dataset.get_max_sequence_length()
    tokenizer.enable_truncation(max_length=MAX_SEQUENCE_LENGTH)
    VOCAB_SIZE = tokenizer.get_vocab_size()
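    # enable_truncation silently caps every document at MAX_SEQUENCE_LENGTH tokens;
    # with the 2 ** 5 = 32 placeholder above, most of each CDR document is dropped
    # until get_max_sequence_length is implemented properly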
    model = get_model()
    model.to(args.device)  # batches are moved to args.device, so the model must live there too
    loss_fn = torch.nn.CrossEntropyLoss()

    # Saving best practices: if you use save() for the model and tokenizer, you can reload them
    if args.do_train:
        # Create output directory if needed
        os.makedirs(args.output_dir, exist_ok=True)

        global_step, tr_loss = train(args, train_dataset, model, tokenizer, loss_fn)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
        tokenizer.save(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
        logger.info("Model, tokenizer and args saved in %s", args.output_dir)

    # Evaluation
    results = {}
    if args.do_eval:
        tokenizer = load_tokenizer(args.output_dir)
        eval_dataset = LineByLineTextDataset(args.eval_data_file, tokenizer)
        tokenizer.enable_truncation(max_length=eval_dataset.get_max_sequence_length())

        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            logger.info("Evaluating checkpoint: %s", checkpoint)
            global_step = checkpoint.split("-")[-1]
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            # Load a trained model and vocabulary that you have fine-tuned
            model = get_model()
            model.load_state_dict(torch.load(os.path.join(checkpoint, WEIGHTS_NAME)))
            model.to(args.device)
            if args.n_gpus > 1 and not isinstance(model, torch.nn.DataParallel):
                model = torch.nn.DataParallel(model)

            result = evaluate(args, eval_dataset, model, tokenizer, loss_fn, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

        logger.info("Eval results: {}".format(results))