AR examples: sample Arabic (plus a small amount of English) text from MADLAD-400 with DSIR, then extend the Mistral-7B tokenizer with the sampled data.
import re
import gc
import os
import glob
import json
from copy import deepcopy
from datasets import concatenate_datasets, Dataset
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download

# Load a sampling of AR data.
dataset = (
    concatenate_datasets(
        [
            Dataset.from_json(path)
            for path in glob.glob("madlad-ar-sampled/*.jsonl")
            if os.stat(path).st_size
        ]
    )
    .shuffle(seed=42)
    .select(range(1000000))
)
# Yield dataset in batches.
batch_size = 1000

def batch_iterator():
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]
        gc.collect()
# Download and initialize the original mistral tokenizer.
snapshot_download(
    repo_id="mistralai/Mistral-7B-v0.1",
    local_dir="mistral-7b-v0.1",
    allow_patterns=["*tokenizer*"],
)
tokenizer = AutoTokenizer.from_pretrained("mistral-7b-v0.1", use_fast=True)
# The train_new_from_iterator method ignores previous tokens, so the vocab_size param should
# be the number of new tokens desired, not total tokens; in this case 2**16 - len(tokenizer),
# targeting a combined vocabulary of roughly 2**16. You can limit the maximum length of new
# tokens with max_token_length as well, which helps prevent commonly repeated URLs,
# disclaimers, and other noise from entering the vocab.
new_vocab_size = 2**16 - len(tokenizer)
new_tokenizer = tokenizer.train_new_from_iterator(
    batch_iterator(), vocab_size=new_vocab_size, max_token_length=24
)
new_tokenizer.save_pretrained("mistral-7b-tokenizer-ar-temp")

# Load the original tokenizer (downloaded to mistral-7b-v0.1 above).
with open("mistral-7b-v0.1/tokenizer.json") as f:
    original = json.load(f)

# Load the updated tokenizer we just trained.
with open("mistral-7b-tokenizer-ar-temp/tokenizer.json") as f:
    append = json.load(f)
def merge_tokenizer(original_data: dict, append_data: dict):
    original_vocab = original_data["model"]["vocab"]
    append_vocab = append_data["model"]["vocab"]
    vocab_out = deepcopy(original_vocab)
    data_out = deepcopy(original_data)
    idx = max(vocab_out.values())

    # Append the new vocab tokens, ignoring numeric values since they decrease math/reasoning performance.
    for token in append_vocab.keys():
        if token not in original_vocab and not (
            re.search(r"[0-9]", token) and re.match(r"^([^\w]|[0-9])+$", token)
        ):
            idx += 1
            vocab_out[token] = idx
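    # Illustrative examples (not from the original gist): "2023" and "12.5" would be skipped
    # by the digit filter above, while tokens such as "مرحبا" or "v2" (which contain non-digit
    # word characters) are appended. Note that [0-9] only matches ASCII digits, so
    # Arabic-Indic numerals are not filtered.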
    # Update merges.
    merges_out = []
    for candidate, piece_id in vocab_out.items():
        for i in range(1, len(candidate)):
            left, right = candidate[:i], candidate[i:]
            left_id = vocab_out.get(left, None)
            right_id = vocab_out.get(right, None)
            if left_id is not None and right_id is not None:
                if (
                    re.search(r"[0-9]", left)
                    and re.match(r"^([^\w]|[0-9])+$", left)
                    and re.search(r"[0-9]", right)
                    and re.match(r"^([^\w]|[0-9])+$", right)
                ):
                    continue
                merges_out += [f"{left} {right}"]
    data_out["model"]["vocab"] = vocab_out
    data_out["model"]["merges"] = merges_out

    # Save the original tokenizer first to get the config/special-token files, then
    # overwrite its tokenizer.json with the merged vocab and merges.
    tokenizer.save_pretrained("mistral-7b-tokenizer-ar")
    with open("mistral-7b-tokenizer-ar/tokenizer.json", "w") as f:
        json.dump(data_out, f, ensure_ascii=False, indent=2)

merge_tokenizer(original_data=original, append_data=append)
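A quick sanity check, not part of the original gist (the sample sentence is arbitrary): compare how the original Mistral tokenizer and the merged Arabic tokenizer encode the same Arabic text; the merged vocabulary should produce noticeably fewer tokens.

from transformers import AutoTokenizer

original_tok = AutoTokenizer.from_pretrained("mistral-7b-v0.1", use_fast=True)
merged_tok = AutoTokenizer.from_pretrained("mistral-7b-tokenizer-ar", use_fast=True)
sample = "الذكاء الاصطناعي يغير العالم"  # "Artificial intelligence is changing the world"
print(len(original_tok(sample)["input_ids"]), len(merged_tok(sample)["input_ids"]))

The sampling script below is what produces the madlad-ar-sampled/ files consumed at the top of the script above.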
import glob
import subprocess
import datasets
import os
from data_selection import HashedNgramDSIR
from huggingface_hub import snapshot_download
from loguru import logger

# Download the AR data, as well as a small sampling of EN data.
logger.info("Downloading data files...")
snapshot_download(
    repo_id="allenai/madlad-400",
    local_dir="madlad-400",
    cache_dir=".cache",
    local_dir_use_symlinks=False,
    allow_patterns=[
        "data/ar/ar_clean_*.gz",
        "data/en/en_clean_000*.gz",
    ],
    repo_type="dataset",
)
logger.info("Extracting gzips...") | |
current_path = os.getcwd() | |
for path in map(str, glob.glob("madlad-400/data/*/*clean*.gz", recursive=True)): | |
logger.info(f"Extracting: {path}") | |
os.chdir(os.path.dirname(os.path.abspath(path))) | |
subprocess.run(["gunzip", path]) | |
# Sample AR datasets.
logger.info("Sampling AR datasets...")
# Need to initialize an empty target dataset file - if you want to extract
# data with a relative "importance", you can populate this with the target data.
with open("madlad-400-ar-sample.jsonl", "a+") as outfile:
    ...
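# Illustrative only (an assumption, not part of the original flow): to bias DSIR toward a
# specific target distribution, the file above could instead be seeded with reference
# documents, e.g.:
#
#   import json
#   with open("madlad-400-ar-sample.jsonl", "w") as outfile:
#       for doc in reference_ar_texts:  # hypothetical list of high-quality Arabic texts
#           outfile.write(json.dumps({"text": doc}, ensure_ascii=False) + "\n")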
# Filter down the AR dataset via DSIR to N documents.
ar_datasets = glob.glob("madlad-400/data/ar/ar_clean_*")
dsir = HashedNgramDSIR(
    ar_datasets,
    ["madlad-400-ar-sample.jsonl"],
    cache_dir=".cache/dsir",
)
dsir.fit_importance_estimator(num_tokens_to_fit="auto")
dsir.compute_importance_weights()
dsir.resample(
    out_dir="madlad-ar-sampled",
    num_to_sample=5000000,
    cache_dir=".cache/resampled",
)
# Sample EN datasets at a much lower ratio, just to help maintain base model capabilities.
logger.info("Sampling EN datasets...")
with open("madlad-400-en-sample.jsonl", "a+") as outfile:
    ...
en_datasets = glob.glob("madlad-400/data/en/en_clean_*")
dsir = HashedNgramDSIR(
    en_datasets,
    ["madlad-400-en-sample.jsonl"],
    cache_dir=".cache/dsir-en",
)
dsir.fit_importance_estimator(num_tokens_to_fit="auto")
dsir.compute_importance_weights()
dsir.resample(
    out_dir="madlad-en-sampled",
    num_to_sample=500000,
    cache_dir=".cache/resampled-en",
)
# Load and unify the various EN/AR files.
logger.info("Unifying dataset...")
sample_files = list(glob.glob("madlad-ar-sampled/*.jsonl")) + list(
    glob.glob("madlad-en-sampled/*.jsonl")
)
all_datasets = []
for path in sample_files:
    if os.stat(path).st_size:
        all_datasets.append(datasets.Dataset.from_json(path))

# Combine everything.
datasets.concatenate_datasets(all_datasets).shuffle(seed=42).to_parquet(
    "madlad-pretrain-sample-combined.parquet"
)
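As a quick check (not part of the original gist; the file name comes from the call above), the combined parquet can be reloaded to confirm the row count and inspect a record:

combined = datasets.Dataset.from_parquet("madlad-pretrain-sample-combined.parquet")
print(len(combined))  # expect roughly 5.5M rows if both resample steps completed
print(combined[0]["text"][:200])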