-
-
Save krammnic/b5035a339a6684a13eca6c037c6b82b6 to your computer and use it in GitHub Desktop.
compress tokenizer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchtune.data import Message | |
from torchtune.modules.transforms.tokenizers import HuggingFaceModelTokenizer | |
# Paths to the Gemma tokenizer artifacts used by the oracle test below.
# TOKENIZER_PATH is the file that batch_remove_lines() shrinks in place.
TOKENIZER_CONFIG_PATH = "tokenizer_config_gemma.json"
GENERATION_CONFIG_PATH = "generation_config_gemma.json"
TOKENIZER_PATH = "tokenizer_gemma_cropped.json"
def test_huggingface_model_tokenizer():
    """Oracle for the tokenizer-compression loop.

    Tokenizes a fixed three-message conversation with the (cropped)
    Gemma tokenizer and checks the result against known-good token ids
    and mask values.

    Returns:
        bool: True when the cropped tokenizer still produces the
        expected tokens/mask; False on any mismatch or on any error
        (e.g. the cropped tokenizer JSON is currently invalid).
    """
    try:
        # Initialize tokenizer from the module-level artifact paths.
        model_tokenizer = HuggingFaceModelTokenizer(
            tokenizer_json_path=str(TOKENIZER_PATH),
            tokenizer_config_json_path=str(TOKENIZER_CONFIG_PATH),
            generation_config_path=str(GENERATION_CONFIG_PATH),
        )
        # Fixed test conversation; the first user message is masked.
        messages = [
            Message(
                role="user",
                content="hello there",
                masked=True,
            ),
            Message(
                role="assistant",
                content="hi",
                masked=False,
            ),
            Message(
                role="user",
                content="whatsup?",
                masked=False,
            ),
        ]
        tokens, mask = model_tokenizer.tokenize_messages(messages)
        # Expected ids for the conversation (last 4 trailing tokens ignored).
        expected_tokens = [
            2,
            106,
            1645,
            108,
            17534,
            1104,
            107,
            108,
            106,
            2516,
            108,
            544,
            107,
            108,
            106,
            1645,
            108,
            5049,
            15827,
            235336,
            107,
            108,
        ]
        expected_mask = [True] * 8 + [False] * 14
        # NOTE: explicit comparisons instead of `assert` — asserts are
        # stripped under `python -O`, which would make this oracle pass
        # unconditionally and let batch_remove_lines() delete the whole
        # tokenizer file.
        return tokens[:-4] == expected_tokens and mask[:-4] == expected_mask
    except Exception as e:
        print(f"Test failed with error: {e}")
        return False
def batch_remove_lines(file_path, batch_size=100): | |
with open(file_path, 'r') as f: | |
lines = f.readlines() | |
total_lines = len(lines) | |
kept_lines = set(range(total_lines)) # Track kept lines by index | |
modified = False | |
# Process batches from the end (avoids index shifting) | |
for start_idx in range(0, total_lines, batch_size): | |
end_idx = min(start_idx + batch_size, total_lines) | |
batch_indices = set(range(start_idx, end_idx)) | |
# Skip if no lines left in batch | |
if not batch_indices.intersection(kept_lines): | |
continue | |
# Create test content (exclude current batch) | |
test_content = "".join( | |
line for i, line in enumerate(lines) | |
if i not in batch_indices and i in kept_lines | |
) | |
# Write test version | |
with open(file_path, 'w') as f: | |
f.write(test_content) | |
# Run test | |
if test_huggingface_model_tokenizer(): | |
print(f"✅ Tests passed - Removed batch {start_idx}-{end_idx - 1}") | |
kept_lines -= batch_indices # Permanently remove batch | |
modified = True | |
else: | |
print(f"❌ Tests failed - Kept batch {start_idx}-{end_idx - 1}") | |
# Save final content | |
if modified: | |
final_content = "".join( | |
line for i, line in enumerate(lines) | |
if i in kept_lines | |
) | |
with open(file_path, 'w') as f: | |
f.write(final_content) | |
print(f"\nFinal: Kept {len(kept_lines)}/{total_lines} lines") | |
# Execute only when run as a script: importing this module must not
# trigger the destructive in-place rewrite of the tokenizer JSON.
if __name__ == "__main__":
    batch_remove_lines("tokenizer_gemma_cropped.json", batch_size=500)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment