Last active
July 18, 2019 14:05
-
-
Save ben0it8/10247e93f1049c9ade08bace43f2ba1d to your computer and use it in GitHub Desktop.
prepare training and eval loops
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ignite.engine import Engine, Events | |
from ignite.metrics import RunningAverage, Accuracy | |
from ignite.handlers import ModelCheckpoint | |
from ignite.contrib.handlers import CosineAnnealingScheduler, PiecewiseLinear, create_lr_scheduler_with_warmup, ProgressBar | |
import torch.nn.functional as F | |
from pytorch_transformers.optimization import AdamW | |
# Bert optimizer: AdamW over all model parameters.
# correct_bias=False reproduces the original (TensorFlow) BERT Adam, which
# skips Adam's bias correction — the pytorch_transformers-recommended setting
# for BERT fine-tuning.
optimizer = AdamW(model.parameters(), lr=finetuning_config.lr, correct_bias=False)
def update(engine, batch):
    """Single training step: forward pass, scaled backward pass, and an
    optimizer step every `gradient_acc_steps` iterations.

    Returns the (accumulation-scaled) loss as a float for logging.
    """
    model.train()
    tokens, targets = (t.to(finetuning_config.device) for t in batch)
    # Model expects sequence-first layout: [S, B].
    tokens = tokens.transpose(0, 1).contiguous()
    cls_mask = (tokens == tokenizer.vocab[processor.CLS])
    _, loss = model(tokens, clf_tokens_mask=cls_mask, clf_labels=targets)
    # Scale so gradients accumulated over several micro-batches average out.
    loss = loss / finetuning_config.gradient_acc_steps
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), finetuning_config.max_norm)
    # Only step/zero the optimizer at accumulation boundaries.
    if engine.state.iteration % finetuning_config.gradient_acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()
def inference(engine, batch):
    """Evaluation step: forward the batch without gradients and return
    (logits, labels) for ignite metrics to consume.
    """
    model.eval()
    with torch.no_grad():
        tokens, labels = (t.to(finetuning_config.device) for t in batch)
        # Sequence-first layout [S, B] for the model; the padding mask is
        # built from the batch-first tensor, matching the original wiring.
        seq_first = tokens.transpose(0, 1).contiguous()
        logits = model(
            seq_first,
            clf_tokens_mask=(seq_first == tokenizer.vocab[processor.CLS]),
            padding_mask=(tokens == tokenizer.vocab[processor.PAD]),
        )
    return logits, labels
# Wrap the step functions in ignite Engines: `trainer` drives training via
# update(), `evaluator` drives validation via inference().
trainer = Engine(update)
evaluator = Engine(inference)
# add metric to evaluator (reads the (logits, labels) pairs inference() returns)
Accuracy().attach(evaluator, "accuracy")
# add evaluator to trainer: eval on valid set after each epoch
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    # Run a full validation pass and print epoch accuracy as a percentage.
    # NOTE(review): `valid_dl` is defined elsewhere in the file — confirm.
    evaluator.run(valid_dl)
    print(f"validation epoch: {engine.state.epoch} acc: {100*evaluator.state.metrics['accuracy']}")
# lr schedule: linearly warm-up to lr and then to zero
# (milestones are in iterations: warm-up ends at n_warmup, decay ends at the
# last iteration, len(train_dl) * n_epochs)
scheduler = PiecewiseLinear(optimizer, 'lr', [(0, 0.0), (finetuning_config.n_warmup, finetuning_config.lr),
                                              (len(train_dl)*finetuning_config.n_epochs, 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
# add progressbar with loss (running average of update()'s returned loss)
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])
# save checkpoints and finetuning config: one checkpoint per epoch under
# log_dir; require_empty=False allows reusing a non-empty directory
checkpoint_handler = ModelCheckpoint(finetuning_config.log_dir, 'finetuning_checkpoint',
                                     save_interval=1, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'imdb_model': model})
# save config to logdir
# NOTE(review): relies on `os` being imported earlier in the file — confirm.
torch.save(finetuning_config, os.path.join(finetuning_config.log_dir, 'fine_tuning_args.bin'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment