# This script is limited to a single GPU at the moment due to the LSH attention approximation
import argparse
import glob
import logging
import os
from typing import Dict, List
import torch
from reformer_pytorch import ReformerLM
from tokenizers import ByteLevelBPETokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import RandomSampler, DataLoader, SequentialSampler
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange
from transformers import AdamW, get_linear_schedule_with_warmup, WEIGHTS_NAME
from bertology.cdr_dataset import CdrJsonLDataset
from bertology.constants import TOKENS_COLNAME
from bertology.run_language_modeling import _sorted_checkpoints, _rotate_checkpoints
from bertology.run_language_modeling_scratch import mask_tokens
from bertology.utils import set_seed
logger = logging.getLogger(__name__)
class LineByLineTextDataset(Dataset):
    def __init__(self, file_path, tokenizer):
        self.tokenizer = tokenizer
        self.data = CdrJsonLDataset.read_data(file_path)

    def __len__(self):
        """Denotes the total number of samples"""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Generates one sample of data"""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self._tokenize_chunk_bpe(self.data.loc[idx, TOKENS_COLNAME])

    def _tokenize_chunk_bpe(self, doc):
        """Flatten and tokenize a document of sequences

        :param doc: list[list[str]] CDR document that is sentenized -> tokenized using the legacy pipeline
        :return: torch tensor of shape (sequence_length) for the DataLoader to consume
        """
        # TODO: Instead of truncating by last words, take center of document
        flat_doc = [token for sent in doc for token in sent]
        return torch.tensor(self.tokenizer.encode(' '.join(flat_doc)).ids)

    @staticmethod
    def get_max_sequence_length():
        # TODO: Implement df.doc_length.max().tolist()[0]
        return 2 ** 5
def collate(examples: List[torch.Tensor]):
    # Relies on the module-level `tokenizer` created in __main__.
    return pad_sequence(examples, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))
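# Illustration only (assuming the "<pad>" token id resolves to 0):
#   collate([torch.tensor([5, 6, 7]), torch.tensor([8])])
#   -> tensor([[5, 6, 7],
#              [8, 0, 0]])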
def evaluate(args, eval_dataset, model, tokenizer, loss_fn, prefix="") -> Dict:
    eval_output_dir = args.output_dir
    os.makedirs(os.path.join(eval_output_dir, prefix), exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpus)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate,
    )

    # TODO: multi-gpu evaluate
    if args.n_gpus > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = mask_tokens(batch, tokenizer, args)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            output = model(inputs)
            # Only score the positions that were actually masked (labels == -100 are ignored)
            loss_mx = labels != -100
            output_ids = output[loss_mx].view(-1, tokenizer.get_vocab_size())
            labels = labels[loss_mx].view(-1)
            eval_loss += loss_fn(output_ids, labels).mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))
    result = {"perplexity": perplexity.item(), "loss": eval_loss}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
def train(args, train_dataset, model, tokenizer, loss_fn):
    tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpus)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # TODO: multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpus > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps,
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # Set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()

    if args.evaluate_during_training:
        eval_dataset = LineByLineTextDataset(args.eval_data_file, tokenizer)

    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            output = model(inputs)

            # Only calculate loss on masked tokens
            loss_mx = labels != -100
            output = output[loss_mx].view(-1, tokenizer.get_vocab_size())
            labels = labels[loss_mx].view(-1)
            loss = loss_fn(output, labels)

            if args.n_gpus > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.evaluate_during_training:  # Only evaluate on a single GPU, otherwise metrics may not average well
                        results = evaluate(args, eval_dataset, model, tokenizer, loss_fn)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
                    tokenizer.save(output_dir)  # save the tokenizer alongside the checkpoint weights
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saved model, tokenizer and args to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saved optimizer and scheduler states to %s", output_dir)

    tb_writer.close()
    return global_step, tr_loss / global_step
def load_tokenizer(output_dir):
    tokenizer = ByteLevelBPETokenizer(
        f"{output_dir}/vocab.json",
        f"{output_dir}/merges.txt",
    )
    return tokenizer
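# Illustrative sketch (not part of the original gist): the vocab.json and merges.txt
# that load_tokenizer() expects are typically produced by training a byte-level BPE
# tokenizer ahead of time. The corpus path, vocab size and special tokens below are
# hypothetical placeholders; newer `tokenizers` releases write these files via
# save_model(directory) rather than save(directory).
#
#   bpe = ByteLevelBPETokenizer()
#   bpe.train(files=["corpus.txt"], vocab_size=30_000, min_frequency=2,
#             special_tokens=["<pad>", "<unk>", "<mask>"])
#   bpe.save("tokenizer_dir")  # writes tokenizer_dir/vocab.json and tokenizer_dir/merges.txt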
def get_model():
    model = ReformerLM(
        num_tokens=VOCAB_SIZE,
        dim=1024,
        depth=1,
        max_seq_len=MAX_SEQUENCE_LENGTH,
        heads=8,
        lsh_dropout=0.1,
        ff_dropout=0.1,
        post_attn_dropout=0.1,
        layer_dropout=0.1,  # layer dropout from the 'Reducing Transformer Depth on Demand' paper
        causal=False,  # auto-regressive or not
        bucket_size=64,  # average size of qk per bucket; 64 was recommended in the paper
        n_hashes=4,  # 4 is permissible per the author, 8 is the best but slower
        emb_dim=1024,  # embedding factorization for further memory savings
        ff_chunks=200,  # number of chunks for the feedforward layer; increase if there are memory issues
        attn_chunks=8,  # process LSH attention in chunks; the only way for memory to fit when scaling to 16k tokens
        num_mem_kv=128,  # persistent learned memory key values, from the all-attention paper
        twin_attention=False,  # if True, both branches of the reversible network will be attention
        full_attn_thres=1024,  # use full attention if context length is less than the set value
        reverse_thres=1024,  # turn off reversibility (for a 2x speedup) for sequence lengths at or below this value
        use_scale_norm=True,  # use scale norm from the 'Transformers without Tears' paper
        one_value_head=True,  # use one set of values for all heads, from 'One Write-Head Is All You Need'
        weight_tie=False,  # tie parameters of each layer so additional depth costs no extra memory
        weight_tie_embedding=True,  # use the token embedding for the output projection; some papers report better results
        use_full_attn=False,  # only turn on to override LSH and use full attention for all sequence lengths, e.g. for comparison
        axial_position_emb=True,
        axial_position_shape=(8, 4),  # the shape must multiply up to max_seq_len (8 x 4 = 32)
        # axial_position_shape=(128, 128),  # e.g. for max_seq_len = 16384 (128 x 128 = 16384)
        axial_position_dims=(512, 512),  # the dims must sum up to the model dimension (512 + 512 = 1024)
    )
    return model
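# Sanity-check sketch (illustrative, not part of the original script): with axial
# position embeddings, the product of axial_position_shape must equal max_seq_len
# and axial_position_dims must sum to the embedding dimension, which the settings
# above satisfy for MAX_SEQUENCE_LENGTH = 2 ** 5 = 32 and emb_dim = 1024.
#
#   import math
#   assert math.prod((8, 4)) == 32        # matches max_seq_len
#   assert sum((512, 512)) == 1024        # matches emb_dim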
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
    )
    parser.add_argument(
        "--should_continue", action="store_true", help="Whether to continue from the latest checkpoint in output_dir."
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="The model checkpoint for weights initialization. Leave None to train a model from scratch.",
    )
    parser.add_argument(
        "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for the masked language modeling loss."
    )
    parser.add_argument(
        "--config_name",
        default=None,
        type=str,
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
    )
    parser.add_argument(
        "--tokenizer_path",
        default=None,
        type=str,
        help="Path to the directory containing the pretrained ByteLevelBPETokenizer files (vocab.json and merges.txt).",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=8,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X update steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save a checkpoint every X update steps.")
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=None,
        help="Limit the total number of checkpoints by deleting older checkpoints in output_dir. Does not delete by default.",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with a step number.",
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory."
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA even when it is available.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit.",
    )
    args = parser.parse_args()
    if args.eval_data_file is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )

    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    args.device = torch.device("cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpus = torch.cuda.device_count()
    logger.warning(
        "Device: %s, n_gpus: %s, 16-bits training: %s",
        args.device,
        args.n_gpus,
        args.fp16,
    )

    # Set seed
    args.seed = 42
    set_seed(args)

    tokenizer = load_tokenizer(args.tokenizer_path)
    train_dataset = LineByLineTextDataset(args.train_data_file, tokenizer)
    MAX_SEQUENCE_LENGTH = train_dataset.get_max_sequence_length()
    tokenizer.enable_truncation(max_length=MAX_SEQUENCE_LENGTH)
    VOCAB_SIZE = tokenizer.get_vocab_size()

    model = get_model()
    model.to(args.device)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Saving best-practices: if you use save() for the model and tokenizer, you can reload them
    if args.do_train:
        # Create output directory if needed
        os.makedirs(args.output_dir, exist_ok=True)

        global_step, tr_loss = train(args, train_dataset, model, tokenizer, loss_fn)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
        tokenizer.save(args.output_dir)
        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
        logger.info("Model, tokenizer and args saved in %s", args.output_dir)

    # Evaluation
    results = {}
    if args.do_eval:
        tokenizer = load_tokenizer(args.output_dir)
        eval_dataset = LineByLineTextDataset(args.eval_data_file, tokenizer)
        tokenizer.enable_truncation(max_length=eval_dataset.get_max_sequence_length())

        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            logger.info("Evaluating checkpoint: %s", checkpoint)
            global_step = checkpoint.split("-")[-1]
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            # Load a trained model and vocabulary that you have fine-tuned
            model = get_model()
            model.load_state_dict(torch.load(os.path.join(checkpoint, WEIGHTS_NAME)))
            model.to(args.device)
            if args.n_gpus > 1 and not isinstance(model, torch.nn.DataParallel):
                model = torch.nn.DataParallel(model)

            result = evaluate(args, eval_dataset, model, tokenizer, loss_fn, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

        logger.info("Eval results: {}".format(results))