Last active
January 17, 2024 02:36
-
-
Save pszemraj/ec852985f5937ef92d9534796bbd183c to your computer and use it in GitHub Desktop.
Loads a Hugging Face Transformers tokenizer, checks for essential special tokens, and adds any that are missing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer


def load_and_ensure_tokens(model_name: str):
    """Load a Hugging Face tokenizer and ensure essential special tokens are set.

    Any essential special token missing from the pretrained tokenizer is
    registered under its proper role (pad, eos, bos, mask, unk, sep) so that
    downstream code can rely on attributes like ``tokenizer.pad_token``.

    Args:
        model_name: Model id on the Hugging Face Hub or a local path, passed
            to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer with all essential special tokens populated.

    Note:
        If this grows the tokenizer's vocabulary, any model using it must
        resize its embeddings:
        ``model.resize_token_embeddings(len(tokenizer))``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Essential special-token roles with their default literal values.
    essential_tokens = {
        "pad_token": "<pad>",
        "eos_token": "<eos>",
        "bos_token": "<bos>",
        "mask_token": "<mask>",
        "unk_token": "<unk>",
        "sep_token": "<sep>",
    }

    # Keep only the roles the pretrained tokenizer does not already fill.
    missing_tokens = {
        role: token
        for role, token in essential_tokens.items()
        if getattr(tokenizer, role, None) is None
    }

    if missing_tokens:
        # Pass the role -> token mapping directly: this adds each token to the
        # vocabulary AND assigns it to its role in a single call. (Registering
        # them under "additional_special_tokens", as before, adds them to the
        # vocab but leaves pad_token/eos_token/... unset, which is why the
        # original needed a second setattr pass.)
        tokenizer.add_special_tokens(missing_tokens)

    return tokenizer
# Example usage (requires network access to the Hugging Face Hub).
# Guarded so that importing this module does not trigger a tokenizer download.
if __name__ == "__main__":
    tokenizer = load_and_ensure_tokens("tiiuae/falcon-180B")

    # Test: print out the special tokens to verify they are all present.
    print("Special tokens:", tokenizer.special_tokens_map)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment