Created
August 3, 2023 09:59
-
-
Save twobob/0a366b0f9628e2cc0aba9cb93d4acce3 to your computer and use it in GitHub Desktop.
Analyse per-token timing logs (tokens per second) from babyllama runs.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
# Function to calculate tokens per second, discarding intervals that cannot yield a rate
def calculate_tokens_per_second_filtered(cumulative_time):
    """Convert per-token timestamps into instantaneous tokens-per-second values.

    Each consecutive pair of timestamps is treated as the time taken to emit
    one token, so the rate for that token is ``1 / interval``.

    Args:
        cumulative_time: 1-D sequence of timestamps in milliseconds, one per
            generated token, in logging order (absolute or relative to the
            run start -- only the differences are used).

    Returns:
        numpy array of tokens-per-second values, one per strictly positive
        interval. Zero intervals (several tokens logged within the same
        millisecond) and negative intervals (out-of-order timestamps) are
        dropped: the former would divide by zero and the latter would give a
        meaningless negative rate.
    """
    time_diffs_seconds = np.diff(np.asarray(cumulative_time, dtype=float)) / 1000
    # Strictly positive only; the original "!= 0" filter let negative
    # intervals through as negative tokens/second.
    positive_diffs = time_diffs_seconds[time_diffs_seconds > 0]
    return 1 / positive_diffs
# Function to process a log file and return the tokens-per-second values for each run separately
def process_log_file_per_run_filtered(file_path):
    """Read a timing log and return the tokens-per-second values of every run.

    Layout as read here: the first line of the file is skipped; every
    following line is one run, a comma-separated list whose first field is
    treated as a header and ignored, and whose remaining non-empty fields are
    millisecond timestamps, one per generated token. Empty fields (e.g. from a
    trailing comma) are ignored.

    Args:
        file_path: path to the log file.

    Returns:
        list with one numpy array per run that contains at least one
        timestamp; each array holds that run's tokens-per-second values as
        computed by ``calculate_tokens_per_second_filtered``. Lines with no
        timestamps (blank or header-only) are skipped.
    """
    runs_tokens_per_second = []
    with open(file_path, 'r', encoding='utf-8') as file:
        file.readline()  # Skip the first line (blank in the logs this targets)
        for line in file:
            fields = line.strip().split(',')[1:]  # drop the header field
            timestamps = np.array([float(field) for field in fields if field.strip()], dtype=float)
            if timestamps.size == 0:
                # Nothing to measure; an empty run would also break curve fitting later.
                continue
            # Elapsed time since the run's first logged token. The former
            # cumsum(diff(timestamps)) omitted the leading 0, which made the
            # subsequent diff silently drop the first inter-token interval.
            elapsed_ms = timestamps - timestamps[0]
            runs_tokens_per_second.append(calculate_tokens_per_second_filtered(elapsed_ms))
    return runs_tokens_per_second
# Function to plot scatter graphs with best-fit curves for each run in a dataset
def plot_scatter_with_best_fit_per_run(runs_tokens_per_second, title, deg=2):
    """Scatter each run's tokens/second against token index, with a polynomial trend.

    Draws a single figure for the whole dataset. Every run gets its own
    scatter series plus a dashed best-fit curve in the same colour.

    Args:
        runs_tokens_per_second: iterable of 1-D sequences, one per run (as
            returned by ``process_log_file_per_run_filtered``).
        title: figure title.
        deg: degree of the fitted polynomial (default 2, the previously
            hard-coded value).

    Runs with no samples are skipped. Runs with fewer than ``deg + 1`` samples
    are scattered without a trend line, because ``np.polyfit`` raises on empty
    input and is rank-deficient when there are fewer points than coefficients.
    """
    plt.figure(figsize=(12, 6))
    for tokens_per_second in runs_tokens_per_second:
        y_data = np.asarray(tokens_per_second, dtype=float)
        if y_data.size == 0:
            continue
        x_data = np.arange(y_data.size)
        points = plt.scatter(x_data, y_data, s=5, alpha=0.5)
        if y_data.size > deg:
            params = np.polyfit(x_data, y_data, deg=deg)
            best_fit_curve = np.polyval(params, x_data)
            # Reuse the scatter colour so each trend line is visibly tied to its run.
            plt.plot(x_data, best_fit_curve, linestyle='--', alpha=0.5,
                     color=points.get_facecolor()[0])
    plt.xlabel('Time Steps')
    plt.ylabel('Tokens Per Second')
    plt.title(title)
    plt.grid(True)
    plt.show()
# Usage example: point the paths below at real timing logs and run this file as a script.
if __name__ == '__main__':
    # Label -> timing log path. The labels look like babyllama model sizes
    # (parameter counts), not token counts -- NOTE(review): confirm.
    log_files = {
        '15M': '/path/to/15M_file.csv',
        '42M': '/path/to/42M_file.csv',
        '110M': '/path/to/110M_file.csv',
    }
    for model_size, file_path in log_files.items():
        runs_tokens_per_second = process_log_file_per_run_filtered(file_path)
        plot_scatter_with_best_fit_per_run(
            runs_tokens_per_second,
            f'{model_size} Model: Tokens Per Second vs Time (Per Run)',
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
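Add the following line inside the token-generation loop so that a timestamp (via time_in_ms()) is written for every generated token: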
fprintf(timelog_file, "%ld,", time_in_ms());