Train a tokenizer with HF tokenizers - WIP script
import gzip
import logging
from pathlib import Path

import datasets
import fire
from tokenizers import (
    Tokenizer,
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    trainers,
)
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Tokenizer initialization
def initialize_tokenizer(model_type="Unigram", vocab_size=20000):
    tokenizer = Tokenizer(getattr(models, model_type)())
    tokenizer.normalizer = normalizers.NFKC()
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
    tokenizer.decoder = decoders.ByteLevel()
    # The trainer class must match the model type: models.Unigram needs a
    # UnigramTrainer, models.BPE a BpeTrainer, and so on.
    trainer_classes = {
        "Unigram": trainers.UnigramTrainer,
        "BPE": trainers.BpeTrainer,
        "WordPiece": trainers.WordPieceTrainer,
    }
    trainer = trainer_classes[model_type](
        vocab_size=vocab_size,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        special_tokens=["<PAD>", "<BOS>", "<EOS>"],
    )
    return tokenizer, trainer
# Train tokenizer from a list of strings
def train_from_list(tokenizer, trainer, data):
    tokenizer.train_from_iterator(data, trainer=trainer)
# Train tokenizer using the HuggingFace datasets library
def train_from_datasets(tokenizer, trainer, dataset_name, split):
    dataset = datasets.load_dataset(dataset_name, split=split)

    def batch_iterator(batch_size=1000):
        for i in tqdm(range(0, len(dataset), batch_size), desc="Loading Batches"):
            yield dataset[i : i + batch_size]["text"]

    tokenizer.train_from_iterator(
        batch_iterator(), trainer=trainer, length=len(dataset)
    )
# Train tokenizer using gzip files (one document per line)
def train_from_gzip(tokenizer, trainer, files):
    def gzip_iterator():
        for path in files:
            # Explicit encoding; text-mode gzip otherwise falls back to the locale default
            with gzip.open(path, "rt", encoding="utf-8") as f:
                for line in f:
                    yield line

    tokenizer.train_from_iterator(gzip_iterator(), trainer=trainer)
# Main function
def main(
    model_type="Unigram",
    vocab_size=20000,
    data_type="list",
    data_path=None,
    dataset_name=None,
    split="train+test+validation",
):
    tokenizer, trainer = initialize_tokenizer(model_type, vocab_size)
    if data_type == "list":
        with open(data_path, "r", encoding="utf-8") as f:
            data = [line.strip() for line in f]
        train_from_list(tokenizer, trainer, data)
    elif data_type == "datasets":
        train_from_datasets(tokenizer, trainer, dataset_name, split)
    elif data_type == "gzip":
        files = list(Path(data_path).glob("*.gz"))
        train_from_gzip(tokenizer, trainer, files)
    else:
        logger.error(f"Unsupported data_type: {data_type}")
        return

    # Saving tokenizer
    output_path = Path("tokenizers") / model_type
    output_path.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(output_path / "tokenizer.json"))
    logger.info(f"Tokenizer saved at: {output_path / 'tokenizer.json'}")


if __name__ == "__main__":
    fire.Fire(main)
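
Because the entry point is wrapped in fire.Fire(main), every keyword argument of main doubles as a command-line flag (e.g. --data_type gzip --data_path ./my_corpus --vocab_size 32000, with the path as a placeholder). The pieces can also be driven directly from Python; a minimal sketch with a toy in-memory corpus, assuming the functions above are in scope (the sentences and the small vocab_size are illustrative, and the learned vocabulary may end up smaller than requested on a corpus this tiny):

corpus = [
    "The quick brown fox jumps over the lazy dog.",
    "Tokenizers map raw text to sequences of integer ids.",
    "Byte-level pre-tokenization keeps every input representable.",
    "Unigram models prune a large seed vocabulary down to size.",
]

tokenizer, trainer = initialize_tokenizer(model_type="Unigram", vocab_size=500)
train_from_list(tokenizer, trainer, corpus)

enc = tokenizer.encode("The quick brown fox")
print(enc.tokens)  # learned subword pieces
print(enc.ids)     # their integer ids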
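
After a run of the script, the saved tokenizer.json can be reloaded with Tokenizer.from_file for a quick round-trip check; the path below matches the script's default output for model_type="Unigram", and the sample sentence is a placeholder:

from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizers/Unigram/tokenizer.json")
enc = tok.encode("Hello, world!")
print(enc.tokens)           # subword pieces
print(tok.decode(enc.ids))  # ByteLevel decoder restores the original text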
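
To use the result with the transformers library, the same file can be wrapped in a PreTrainedTokenizerFast, passing the script's special tokens explicitly; a sketch (transformers is not a dependency of the script itself):

from transformers import PreTrainedTokenizerFast

hf_tok = PreTrainedTokenizerFast(
    tokenizer_file="tokenizers/Unigram/tokenizer.json",
    pad_token="<PAD>",
    bos_token="<BOS>",
    eos_token="<EOS>",
)
print(hf_tok("Hello, world!").input_ids)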