Created
November 11, 2024 21:27
-
-
Save monk1337/a4ef802751c20c6748138206c193ed9a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer | |
import numpy as np | |
from tqdm import tqdm | |
def analyze_token_sizes(dataset, tokenizer):
    """
    Analyze token counts for the combined text fields of every example in
    a dataset's 'train' split.

    Each example's 'instruction', 'input', and 'output' fields are joined
    into one prompt string, tokenized, and the resulting token count
    recorded.

    Args:
        dataset: A mapping with a 'train' split whose examples are dicts
            containing 'instruction', 'input', and 'output' keys
            (e.g. a Hugging Face ``DatasetDict`` — assumed; confirm
            against the caller).
        tokenizer: An object exposing ``encode(text)`` returning a
            sequence of token ids (e.g. a Hugging Face tokenizer).

    Returns:
        tuple: ``(stats, token_counts)`` where ``stats`` is a dict with
        keys 'min_tokens', 'max_tokens', 'avg_tokens', 'median_tokens',
        'std_tokens', '90th_percentile', 'total_samples', and
        ``token_counts`` is the list of per-example token counts.

    Raises:
        ValueError: If the 'train' split contains no examples.
    """
    token_counts = []
    for example in tqdm(dataset['train'], desc="Analyzing tokens"):
        # Combine all text fields into the same prompt shape used downstream.
        combined_text = f"Instruction: {example['instruction']}\nInput: {example['input']}\nOutput: {example['output']}"
        token_counts.append(len(tokenizer.encode(combined_text)))

    # Fail loudly on an empty split instead of letting min()/max() raise
    # a confusing "arg is an empty sequence" error.
    if not token_counts:
        raise ValueError("dataset['train'] contains no examples to analyze")

    stats = {
        'min_tokens': min(token_counts),
        'max_tokens': max(token_counts),
        'avg_tokens': np.mean(token_counts),
        'median_tokens': np.median(token_counts),
        'std_tokens': np.std(token_counts),
        '90th_percentile': np.percentile(token_counts, 90),
        'total_samples': len(token_counts),
    }
    return stats, token_counts
# Run analysis on the clinical-trial dataset (defined earlier in the session).
stats, token_counts = analyze_token_sizes(clinicaltrail, tokenizer)

# Print summary statistics of the token-count distribution.
print("\nToken Size Analysis:")
print(f"Minimum tokens: {stats['min_tokens']:,}")
print(f"Maximum tokens: {stats['max_tokens']:,}")
print(f"Average tokens: {stats['avg_tokens']:.2f}")
print(f"Median tokens: {stats['median_tokens']:.2f}")
print(f"Standard deviation: {stats['std_tokens']:.2f}")
print(f"90th percentile: {stats['90th_percentile']:.2f}")
print(f"Total samples analyzed: {stats['total_samples']:,}")

# Optional: visualize the token-count distribution as a histogram.
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.hist(token_counts, bins=50, edgecolor='black')
plt.title('Distribution of Token Counts')
plt.xlabel('Number of Tokens')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)
# Fix: the figure was built but never rendered — in a non-interactive
# script run nothing appears without an explicit show().
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment