Last active
February 3, 2020 01:22
-
-
Save matpalm/555f99cee4391e4afdacfba05bb29637 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""Learning-rate range test driver.

Builds the training dataset, constructs the model and sweeps the learning
rate geometrically from --initial-learning-rate to --final-learning-rate
over --num-batches batches, then saves a loss-vs-LR plot to /tmp/foo.png.
"""
import argparse
import os

# Must be set *before* tensorflow is imported: TF reads TF_CPP_MIN_LOG_LEVEL
# at import time to configure its C++ logging, so setting it afterwards
# (as the original did) has no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf
from tensorflow.keras.callbacks import *  # NOTE(review): appears unused here; kept in case of import side effects

import data as d
import model as m
from lr_finder import LearningRateFinder

# Enable XLA JIT compilation for the session.
tf.config.optimizer.set_jit(True)

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--train-tf-record-glob', type=str, required=True)
parser.add_argument('--num-batches', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--initial-learning-rate', type=float, default=1e-10)
parser.add_argument('--final-learning-rate', type=float, default=1e-1)
parser.add_argument('--shuffle-buffer-size', type=int, default=64)
parser.add_argument('--expit-squash', type=float, default=1.0)
parser.add_argument('--plus-one-weight', type=float, default=5.0)
parser.add_argument('--self-labelled-weight', type=float, default=1.0)
opts = parser.parse_args()

train_dataset = d.dataset_from_tfrecord(opts.train_tf_record_glob,
                                        batch_size=opts.batch_size,
                                        plus_one_weight=opts.plus_one_weight,
                                        self_labelled_weight=opts.self_labelled_weight,
                                        expit_squash=opts.expit_squash,
                                        training=True,
                                        shuffle_buffer=opts.shuffle_buffer_size)

# The LR passed here is irrelevant: the finder immediately overwrites the
# optimizer's learning rate with --initial-learning-rate.
model = m.construct_model(learning_rate=1e-4)

finder = LearningRateFinder(model)
finder.find(train_dataset,
            initial_learning_rate=opts.initial_learning_rate,
            final_learning_rate=opts.final_learning_rate,
            num_batches=opts.num_batches)
finder.export_plot(fname="/tmp/foo.png")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
from matplotlib import pyplot as plt | |
import tensorflow.keras.backend as K | |
class LearningRateFinder:
    """Learning-rate range test over a compiled Keras model.

    Trains on successive batches while increasing the optimizer's learning
    rate geometrically, recording the loss after each batch. Plotting loss
    against learning rate (see `export_plot`) helps choose a training LR.
    """

    def __init__(self, model):
        # model must already be compiled: we read/write model.optimizer and
        # rely on train_on_batch returning (loss, accuracy).
        self.model = model
        self.losses = []           # per-batch loss values
        self.learning_rates = []   # LR in effect when each loss was recorded

    def find(self, train_dataset, num_batches, initial_learning_rate, final_learning_rate):
        """Sweep the LR from initial to final over num_batches batches.

        Stops early if the loss goes NaN (training has diverged).
        Results are accumulated in self.losses / self.learning_rates.
        """
        # Per-batch multiplier giving a geometric progression from the
        # initial to the final learning rate in num_batches steps.
        learning_rate_ratio = final_learning_rate / initial_learning_rate
        self.learning_rate_multiplier = learning_rate_ratio ** (1 / num_batches)
        print("num_batches", num_batches)
        print("self.learning_rate_multiplier", self.learning_rate_multiplier)
        # Fixed inconsistency: the original set `optimizer.lr` but read
        # `optimizer.learning_rate`; use `learning_rate` uniformly.
        K.set_value(self.model.optimizer.learning_rate, initial_learning_rate)
        print("initial_learning_rate", initial_learning_rate)
        # TODO: we make the assumption that there is enough data to
        # take at _least_ num_batches
        for i, (imgs, labels, _weights) in enumerate(train_dataset.take(num_batches)):
            # TODO: no sample_weight on train_on_batch in this keras version?
            loss, _accuracy = self.model.train_on_batch(imgs, labels)
            if math.isnan(loss):
                break  # diverged; no point sweeping further
            self.losses.append(loss)
            learning_rate = K.get_value(self.model.optimizer.learning_rate)
            self.learning_rates.append(learning_rate)
            # i + 1 fixes the off-by-one progress display (was 0/N .. N-1/N).
            print("%d/%d learning_rate=%s loss=%s" % (i + 1, num_batches, learning_rate, loss))
            learning_rate *= self.learning_rate_multiplier
            K.set_value(self.model.optimizer.learning_rate, learning_rate)

    def export_plot(self, fname):
        """Save a loss-vs-learning-rate plot (log-scaled x axis) to fname."""
        plt.ylabel("loss")
        plt.xlabel("learning rate")
        plt.plot(self.learning_rates, self.losses)
        plt.xscale('log')
        plt.savefig(fname)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment