Last active
November 28, 2023 08:22
-
-
Save ryanpeach/9ef833745215499e77a2a92e71f89ce2 to your computer and use it in GitHub Desktop.
A Python 3 implementation of the early stopping algorithm described in the Deep Learning book by Ian Goodfellow. Untested, needs basic syntax correction.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Python 3 implementation of Deep Learning book early stop algorithm. | |
@book{Goodfellow-et-al-2016, | |
title={Deep Learning}, | |
author={Ian Goodfellow and Yoshua Bengio and Aaron Courville}, | |
publisher={MIT Press}, | |
note={\url{http://www.deeplearningbook.org}}, | |
year={2016} | |
} | |
""" | |
from typing import Iterable

import numpy as np
class Network(object):
    """ Abstract interface for a trainable network; subclass and override.

    Fixes over the original sketch: methods take ``self``; the undefined
    name ``iterable`` is replaced with ``typing.Iterable``; the return
    annotation on ``clone`` is quoted because ``Network`` is not yet bound
    while the class body is being evaluated.
    """

    def train(self, x: Iterable, y: Iterable) -> None:
        """ Run one training step on inputs x with targets y. """
        raise NotImplementedError()

    def error(self, x: Iterable, y: Iterable) -> float:
        """ Return the error of the network on inputs x with targets y. """
        raise NotImplementedError()

    def __call__(self, x: Iterable) -> Iterable:
        """ Return the network's predictions for inputs x. """
        raise NotImplementedError()

    def clone(self) -> "Network":
        """ Return an independent deep copy of this network. """
        raise NotImplementedError()
# NOTE: the original `from exceptions import Warning` is Python 2 only;
# in Python 3 `Warning` is a builtin and the `exceptions` module is gone.
class ConvergenceWarning(Warning):
    """ Used to indicate an infinite loop reached max iteration
    and the underlying function failed to converge. """
    pass
def early_stopping(theta0, train, valid, n = 1, p = 100):
    """ The early stopping meta-algorithm for determining the best amount of time to train.
    REF: Algorithm 7.1 in the Deep Learning book.

    NOTE: tuple parameters like ``(x_train, y_train)`` in a signature are
    Python 2 only (removed by PEP 3113); the pairs are unpacked in the body
    instead, so the call sites are unchanged.

    Parameters:
        theta0: Network; initial network.
        train: tuple; (x_train, y_train) training input/output sets.
        valid: tuple; (x_valid, y_valid) validation input/output sets.
        n: int; Number of steps between evaluations.
        p: int; "patience", the number of evaluations to observe a worsening validation set.
    Returns:
        theta_prime: Network object; The output network.
        i_prime: int; The number of iterations for the output network.
        v: float; The validation error for the output network.
    """
    x_train, y_train = train
    x_valid, y_valid = valid
    # Initialize variables
    theta = theta0.clone()       # The active network
    i = 0                        # The number of training steps taken
    j = 0                        # The number of evaluations since the last update of theta_prime
    v = np.inf                   # The best evaluation error observed thus far
    theta_prime = theta.clone()  # The best network found thus far
    i_prime = i                  # The step index of theta_prime
    while j < p:
        # Update theta by running the training algorithm for n steps
        for _ in range(n):
            theta.train(x_train, y_train)
        # Update values
        i += n
        v_new = theta.error(x_valid, y_valid)
        # If better validation error: reset patience, save the network, record the best error
        if v_new < v:
            j = 0
            theta_prime = theta.clone()
            i_prime = i
            v = v_new
        # Otherwise, consume patience
        else:
            j += 1
    return theta_prime, i_prime, v
def early_stopping_retrain(train, theta0, split_percent = .8, n = 1, p = 100):
    """ Meta-algorithm using early stopping to determine the best number of
    training steps, then retraining on all of the data for that many steps.
    REF: Algorithm 7.2 in the Deep Learning book.

    NOTE: tuple parameters in the signature are Python 2 only (PEP 3113);
    the pair is unpacked in the body instead, so call sites are unchanged.

    Parameters:
        train: tuple; (x_train, y_train) training input/output sets.
        theta0: Network; initial network.
        split_percent: float; the fraction of the training set used as the
            sub-training set; the remainder becomes the validation set.
        n: int; Number of steps between evaluations.
        p: int; "patience", the number of evaluations to observe a worsening validation set.
    Returns:
        theta: Network object; The output network.
        i_prime: int; The number of iterations for the output network.
    """
    x_train, y_train = train
    # Split x_train/y_train into x_subtrain, x_valid and y_subtrain, y_valid
    cut = int(len(x_train) * split_percent)
    x_subtrain, x_valid = x_train[:cut], x_train[cut:]
    y_subtrain, y_valid = y_train[:cut], y_train[cut:]
    # Run early_stopping on the split to find the best step count
    _, i_prime, _ = early_stopping(theta0.clone(), (x_subtrain, y_subtrain), (x_valid, y_valid), n = n, p = p)
    # Reset theta and train on the FULL data for the found number of steps
    theta = theta0.clone()
    for _ in range(i_prime):
        theta.train(x_train, y_train)
    return theta, i_prime
def early_stopping_continuous(train, theta0, split_percent = .8, n = 1, p = 100, max_iteration = 1e4):
    """ Meta-algorithm using early stopping to determine at what objective
    value we start to overfit, then continuing to train on the full data
    until that value is reached.
    REF: Algorithm 7.3 in the Deep Learning book.

    NOTE: tuple parameters in the signature are Python 2 only (PEP 3113);
    the pair is unpacked in the body instead, so call sites are unchanged.

    Parameters:
        train: tuple; (x_train, y_train) training input/output sets.
        theta0: Network; initial network.
        split_percent: float; the fraction of the training set used as the
            sub-training set; the remainder becomes the validation set.
        n: int; Number of steps between evaluations.
        p: int; "patience", the number of evaluations to observe a worsening validation set.
        max_iteration: int; maximum number of continuation rounds before giving up.
    Returns:
        theta_prime: Network object; The output network.
        v_new: float; The validation error for the output network.
    Raises:
        ConvergenceWarning: Does not converge to the found optimum within max_iteration rounds.
    """
    x_train, y_train = train
    # Split x_train/y_train into x_subtrain, x_valid and y_subtrain, y_valid
    cut = int(len(x_train) * split_percent)
    x_subtrain, x_valid = x_train[:cut], x_train[cut:]
    y_subtrain, y_valid = y_train[:cut], y_train[cut:]
    # Run early_stopping to find the target validation error v
    theta_prime, i_prime, v = early_stopping(theta0.clone(), (x_subtrain, y_subtrain), (x_valid, y_valid), n = n, p = p)
    # Train on the FULL data until the target error v is reached.
    # int() because the default 1e4 is a float and range() requires an int.
    for _ in range(int(max_iteration)):
        # Train for n iterations
        for _ in range(n):
            theta_prime.train(x_train, y_train)
        # Re-evaluate on the validation split (the original called error()
        # with no arguments, which the Network interface does not allow)
        v_new = theta_prime.error(x_valid, y_valid)
        # If at or below the target validation error, training is done
        if v_new <= v:
            return theta_prime, v_new
    # Training never reached the target before max_iteration; warn the caller.
    # (The original had an unreachable `return` after this raise — removed.)
    raise ConvergenceWarning("early_stopping_continuous failed to converge.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment