Last active
March 14, 2022 10:08
-
-
Save ubless607/b436e980dfeb7b959b9d490f3f13db35 to your computer and use it in GitHub Desktop.
A simple function for plotting a learning curve of the model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _resolve_whitegrid_style():
    '''
    Return the name of the seaborn "whitegrid" style known to the installed
    matplotlib, or None when neither spelling is registered.
    '''
    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*';
    # the old spelling is tried second so older installations keep working.
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            return style_name
    return None


def _draw_learning_curve(epochs, train_values, val_values, title,
                         metric_name, ylim, fig_size):
    '''
    Draw one training/validation figure and show it.
    '''
    plt.figure(figsize=fig_size)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(f'{metric_name}')

    plt.plot(epochs, train_values, '-', label='Training')
    plt.plot(epochs, val_values, '-', label='Validation')

    # Limits are applied AFTER plotting: setting a limit switches off
    # autoscaling for the axis, so doing it on empty axes would leave the
    # unspecified bound at the empty-axes default instead of fitting the data.
    bottom, top = ylim
    if bottom is not None:
        plt.ylim(bottom=bottom)
    if top is not None:
        plt.ylim(top=top)

    plt.legend()
    plt.tight_layout()
    plt.show()


def plot_learning_curve(log_df,
                        metric_name='loss',
                        rolling=False,
                        window_size=5,
                        ylim=(None, None), **kwargs):
    '''
    A simple function for plotting a learning curve of the model

    Args:
        log_df: input pandas Dataframe. Must contain the columns 'epoch',
            '<metric_name>' and 'val_<metric_name>'
        metric_name: name of the metric to plot
        rolling: Defaults to False. If set to True, plot a moving averaged graph in the second figure
        window_size: size of the moving window
        ylim: y-axis limit, Tuple of (bottom, top). A None entry leaves that
            bound to matplotlib's autoscaling
        **kwargs:
            fig_size: figure size passed to plt.figure(figsize=...),
                defaults to (8, 4)

    Raises:
        KeyError: if one of the required columns is missing from log_df

    Reference:
        https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.rolling.html

    Author:
        @ubless607, @yubin8773
    '''
    fig_size = kwargs.get('fig_size', (8, 4))

    # Data from the log.csv
    # The logged epoch values are shifted by one for display (so a log that
    # starts at 0 is shown starting at 1). Using the column itself, rather
    # than a range built from its first and last value, keeps x and y the
    # same length even when the log has gaps or repeated epochs.
    epochs = log_df['epoch'].to_numpy() + 1
    train_values = log_df[f'{metric_name}']
    val_values = log_df[f'val_{metric_name}']

    style_name = _resolve_whitegrid_style()
    if style_name is not None:
        plt.style.use(style_name)

    _draw_learning_curve(epochs, train_values, val_values,
                         f'Learning Curves ({metric_name})',
                         metric_name, ylim, fig_size)

    if rolling:
        # The first (window_size - 1) points of a rolling mean are NaN and
        # are simply not drawn by matplotlib.
        train_mavg = train_values.rolling(window=window_size).mean()
        val_mavg = val_values.rolling(window=window_size).mean()
        _draw_learning_curve(epochs, train_mavg, val_mavg,
                             f'Learning Curves ({metric_name}) - moving average',
                             metric_name, ylim, fig_size)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment