Skip to content

Instantly share code, notes, and snippets.

@cloneofsimo
Created August 16, 2024 05:23
Show Gist options
  • Save cloneofsimo/0ceda12353c1b6b3a303a2ece5e07284 to your computer and use it in GitHub Desktop.
GPT-generated-Plots
def plot_lr_final_loss_batchsize(file_path):
    """Plot final validation loss vs. learning rate, grouped by total batch size.

    Reads a W&B CSV export, finds the last recorded validation loss per run
    (columns named ``val_loss/val_loss_<K>``, taken in increasing ``K`` order),
    and draws a log-log scatter+line plot with one series per total batch size
    (``batch_size * gradient_accumulation_steps``).

    Parameters
    ----------
    file_path : str
        Path to the exported CSV file.

    Returns
    -------
    pandas.DataFrame
        The loaded data with added ``final_val_loss`` and ``total_batch_size``
        columns, for further inspection by the caller.
    """
    # Local imports keep this gist self-contained: the file has no import block.
    import pandas as pd
    import matplotlib.pyplot as plt

    data = pd.read_csv(file_path)

    # Columns like 'val_loss/val_loss_<K>' hold the val loss logged at step K.
    # Sort by K so "last non-NaN" means the latest recorded checkpoint.
    val_loss_columns = [col for col in data.columns if col.startswith('val_loss/val_loss_')]
    val_loss_columns_sorted = sorted(val_loss_columns, key=lambda col: int(col.split('_')[-1]))

    if val_loss_columns_sorted:
        # Vectorized replacement for a per-row apply: forward-fill along the
        # sorted checkpoint axis, then the last column holds each row's last
        # non-NaN value (NaN when the row has no recorded losses at all).
        data['final_val_loss'] = data[val_loss_columns_sorted].ffill(axis=1).iloc[:, -1]
    else:
        # No matching columns in the export: degrade gracefully instead of raising.
        data['final_val_loss'] = pd.NA

    # Effective batch size = per-device batch * gradient accumulation steps.
    data['total_batch_size'] = data['batch_size'] * data['gradient_accumulation_steps']

    plt.figure(figsize=(10, 6))
    for batch_size, group in data.groupby('total_batch_size'):
        plt.scatter(group['learning_rate'], group['final_val_loss'], label=f'Batch Size: {batch_size}')
        # Sort by learning rate so the connecting line sweeps left-to-right.
        sorted_group = group.sort_values('learning_rate')
        plt.plot(sorted_group['learning_rate'], sorted_group['final_val_loss'], marker='')

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Learning Rate (log scale)')
    plt.ylabel('Final Validation Loss (log scale)')
    plt.title('Log-Log Plot of Learning Rate vs. Final Validation Loss')
    plt.legend(title="Total Batch Size", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.show()

    return data

# Run only when executed as a script (keeps imports of this module side-effect free).
if __name__ == "__main__":
    plot_lr_final_loss_batchsize('/mnt/data/wandb_export_2024-08-16T13_35_28.002+09_00.csv')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment