Last active December 13, 2018 18:27
foowaa/917e3bdc1f963480c729ea9f528dce1f
training procedure
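from tqdm import tqdm


def train_model(num_epochs, train, dev, evalp, model,
                metric_best, metric_stop, cnt_stop):
    '''
    num_epochs: number of epochs to run
    train: training data (an iterable of batches)
    dev: validation (dev) data
    evalp: run validation every `evalp` epochs
    model: the model; must provide train_batch(data), a `loss` attribute,
           and evaluate(dev) returning a metric where higher is better
    metric_best: best metric recorded so far (initial value)
    metric_stop: stop training as soon as the dev metric exceeds this value
    cnt_stop: stop training once the dev metric has failed to reach
              metric_best on more than cnt_stop consecutive evaluations
    '''
    cnt = 0  # consecutive evaluations without improvement (was never initialized)
    pbar_epochs = tqdm(range(num_epochs))
    for epoch in pbar_epochs:
        pbar_epochs.set_description("Epoch:{:d}.".format(epoch))
        pbar = tqdm(train)
        for data in pbar:
            model.train_batch(data)
            pbar.set_description("Loss:{:.4f}.".format(model.loss))
            pbar.refresh()
        if (epoch + 1) % evalp == 0:
            metric = model.evaluate(dev)
            if metric >= metric_best:  # ties count as an improvement
                metric_best = metric
                cnt = 0
            else:
                cnt += 1
            if metric > metric_stop:
                break  # target metric reached
            if cnt > cnt_stop:
                break  # early stopping: no improvement for too long
        pbar_epochs.refresh()
    return metric_best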
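The stopping rule in this loop is easy to misread. Ties with metric_best count as an improvement (`>=`), and because the patience check is `cnt > cnt_stop`, training stops on the (cnt_stop + 1)-th consecutive evaluation without improvement, not the cnt_stop-th. The sketch below replays only that rule on a list of dev scores, with no model or tqdm involved. The function name `stopping_epoch` and the sample scores are hypothetical, chosen for illustration.

```python
def stopping_epoch(dev_metrics, evalp, metric_best, metric_stop, cnt_stop):
    """Replay the training loop's stopping rule on a list of dev scores.

    dev_metrics[i] is a hypothetical score from the i-th evaluation, which
    the loop runs at the end of 0-based epoch (i + 1) * evalp - 1.
    Returns (epoch at which training stops, or None, final metric_best).
    """
    cnt = 0  # consecutive evaluations without improvement
    for i, metric in enumerate(dev_metrics):
        epoch = (i + 1) * evalp - 1  # epoch whose end triggered this evaluation
        if metric >= metric_best:    # ties reset the patience counter
            metric_best = metric
            cnt = 0
        else:
            cnt += 1
        if metric > metric_stop:     # target metric reached
            return epoch, metric_best
        if cnt > cnt_stop:           # patience exhausted
            return epoch, metric_best
    return None, metric_best


# With cnt_stop=2, training stops after three non-improving evaluations in a row.
print(stopping_epoch([0.70, 0.80, 0.79, 0.78, 0.77, 0.90],
                     evalp=2, metric_best=0.0, metric_stop=0.95, cnt_stop=2))
# → (9, 0.8)
```

In this run the score of 0.90 at the sixth evaluation is never seen. Change the check to `cnt >= cnt_stop` if cnt_stop should count failures exactly.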