def _final_val_loss(data, prefix='val_loss/val_loss_'):
    """Return each row's last logged validation loss as a float Series.

    Columns named ``<prefix><K>`` (K a non-negative integer) are ordered by K,
    and for every row the value from the largest K that is not NaN is taken.
    Rows with no logged value get NaN (a float, so the result stays numeric
    and can be dropped or plotted safely). Columns that start with ``prefix``
    but have a non-integer suffix are ignored.

    Raises:
        ValueError: if ``data`` has no ``<prefix><K>`` columns.
    """
    import pandas as pd

    # Map each matching column to its numeric K so the columns can be put in
    # step order (a plain string sort would put "..._1000" before "..._200").
    steps = {
        col: int(col[len(prefix):])
        for col in data.columns
        if col.startswith(prefix) and col[len(prefix):].isdecimal()
    }
    if not steps:
        raise ValueError(f"no columns of the form '{prefix}<K>' in data")
    ordered = sorted(steps, key=steps.get)

    # Coerce to float so stray non-numeric cells become NaN. Then forward-fill
    # along the K axis: afterwards the last column holds each row's most
    # recent non-NaN loss, or NaN if the row never logged one.
    losses = data[ordered].apply(pd.to_numeric, errors='coerce').astype(float)
    return losses.ffill(axis=1).iloc[:, -1]


def plot_lr_final_loss_batchsize(file_path, *, show=True):
    """Plot final validation loss against learning rate, one curve per total batch size.

    Reads a CSV with one row per run and uses the columns
    ``val_loss/val_loss_<K>``, ``learning_rate``, ``batch_size`` and
    ``gradient_accumulation_steps``. For each run the final validation loss is
    the value at the largest K that is not NaN. Runs with no such value are
    left off the plot. Both axes use a log scale.

    Args:
        file_path: Path (or anything ``pandas.read_csv`` accepts) to the CSV.
        show: If True (the default), call ``plt.show()`` before returning.

    Returns:
        The ``(fig, ax)`` pair holding the plot.

    Raises:
        ValueError: if the CSV has no ``val_loss/val_loss_<K>`` columns.
    """
    import matplotlib.pyplot as plt
    import pandas as pd

    data = pd.read_csv(file_path)
    data['final_val_loss'] = _final_val_loss(data)
    # Total batch size = batch_size * gradient_accumulation_steps.
    data['total_batch_size'] = data['batch_size'] * data['gradient_accumulation_steps']

    # Runs that never logged a validation loss have nothing to plot.
    plotted = data.dropna(subset=['final_val_loss'])

    fig, ax = plt.subplots(figsize=(10, 6))
    for batch_size, group in plotted.groupby('total_batch_size'):
        # Sort by learning rate so the connecting line runs left to right.
        group = group.sort_values('learning_rate')
        # One call draws both the markers and the connecting line, so they are
        # guaranteed to share a color and a single legend entry (separate
        # scatter()/plot() calls are not guaranteed to pick the same color).
        ax.plot(
            group['learning_rate'],
            group['final_val_loss'],
            marker='o',
            label=f'Batch Size: {batch_size}',
        )

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('Learning Rate (log scale)')
    ax.set_ylabel('Final Validation Loss (log scale)')
    ax.set_title('Log-Log Plot of Learning Rate vs. Final Validation Loss')
    ax.legend(title="Total Batch Size", bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True)
    # The legend is anchored outside the axes; tight_layout shrinks the axes
    # so the legend is not clipped at the figure edge.
    fig.tight_layout()
    if show:
        plt.show()
    return fig, ax
# Run the function with the uploaded file path. Guarded so that importing this
# module does not read the CSV or open a plot window as a side effect.
if __name__ == "__main__":
    plot_lr_final_loss_batchsize('/mnt/data/wandb_export_2024-08-16T13_35_28.002+09_00.csv')
Created
August 16, 2024 05:23
-
-
Save cloneofsimo/0ceda12353c1b6b3a303a2ece5e07284 to your computer and use it in GitHub Desktop.
GPT-generated-Plots
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment