Created December 20, 2023 06:12
import argparse
import glob
import json
import os

import tqdm
from nltk import tokenize
from transformers import LlamaTokenizer

pretrained_model_path = '/home/models/Llama-2-7b-hf'
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_path)


def add_splits(token_count, max_tokens, text, results, min_tokens):
    """Recursively split `text` in half until each piece fits within
    [min_tokens, max_tokens]; pieces that fall below min_tokens are dropped."""
    if min_tokens <= token_count <= max_tokens:
        results.append({
            "text": text
        })
    elif token_count > max_tokens:
        sentences = tokenize.sent_tokenize(text)
        if len(sentences) == 1:
            # No sentence boundary to split on: cut the raw string in half.
            first_half = text[:len(text) // 2]
            second_half = text[len(text) // 2:]
        else:
            # Split at the middle sentence boundary, rejoining with spaces.
            half_length = len(sentences) // 2
            first_half = " ".join(sentences[:half_length])
            second_half = " ".join(sentences[half_length:])
        add_splits(len(tokenizer(first_half)["input_ids"]),
                   max_tokens, first_half, results, min_tokens)
        add_splits(len(tokenizer(second_half)["input_ids"]),
                   max_tokens, second_half, results, min_tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-dir", type=str, required=True)
    parser.add_argument("--out-file", type=str, required=True)
    parser.add_argument("--min-tokens", type=int, default=2048)
    parser.add_argument("--max-tokens", type=int, default=4096)
    args = parser.parse_args()

    # Gather every .jsonl file under the input directory.
    files = glob.glob(os.path.join(args.in_dir, "**", "*.jsonl"), recursive=True)
    new_content = []
    for filename in tqdm.tqdm(files):
        with open(filename, "r") as file:
            for line in tqdm.tqdm(file.readlines()):
                obj = json.loads(line)
                token_count = len(tokenizer(obj["text"])["input_ids"])
                if args.min_tokens <= token_count <= args.max_tokens:
                    # Already within range: keep the record as-is.
                    new_content.append({
                        "text": obj["text"]
                    })
                elif token_count > args.max_tokens:
                    # Too long: split it into in-range pieces.
                    try:
                        add_splits(token_count, args.max_tokens,
                                   obj["text"], new_content, args.min_tokens)
                    except RecursionError:
                        # Skip records that cannot be split within the recursion limit.
                        continue

    # Write the kept and split records back out as JSON Lines.
    json_lines = [json.dumps(record) for record in new_content]
    json_data = '\n'.join(json_lines)
    with open(args.out_file, 'w') as f:
        f.write(json_data)
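For reference, the script is a command-line tool; an invocation along these lines (the paths are illustrative and the script name is whatever the gist was saved as) filters and splits every .jsonl file found under a corpus directory:

python split_long_texts.py --in-dir /data/corpus --out-file /data/corpus_2k-4k.jsonl --min-tokens 2048 --max-tokens 4096

Each line of the output file is a JSON object with a single "text" field whose token count under the Llama-2 tokenizer falls between --min-tokens and --max-tokens; records shorter than --min-tokens are discarded.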