Skip to content

Instantly share code, notes, and snippets.

@ewof
Created December 20, 2023 06:12
Show Gist options
  • Save ewof/aaf041c9a8e4301e58dffd860c9c4a47 to your computer and use it in GitHub Desktop.
import argparse
import glob
import json
import os

import tqdm
from nltk import tokenize
from transformers import LlamaTokenizer
# Path to a local Llama-2-7b checkpoint; adjust for your environment.
pretrained_model_path = '/home/models/Llama-2-7b-hf'
# Module-level tokenizer used for every token count below (loaded once).
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_path)
def add_splits(token_count, max_tokens, text, list, min_tokens):
    """Recursively split ``text`` into chunks whose token counts fall within
    ``[min_tokens, max_tokens]``, appending each kept chunk to ``list``.

    Chunks that fall below ``min_tokens`` after splitting are silently
    dropped, matching the top-level filter in ``__main__``.

    token_count -- token count of ``text`` per the module-level tokenizer
    max_tokens  -- inclusive upper bound for a kept chunk
    text        -- the candidate text to keep or split
    list        -- output accumulator of ``{"text": ...}`` dicts
                   (NOTE: shadows the builtin name; kept for call
                   compatibility with existing callers)
    min_tokens  -- inclusive lower bound for a kept chunk
    """
    if min_tokens <= token_count <= max_tokens:
        list.append({"text": text})
    elif token_count > max_tokens:
        sentences = tokenize.sent_tokenize(text)
        if len(sentences) == 1:
            # One oversized sentence: bisect it by characters.
            # BUGFIX: the original used text[:n//2] + text[(n//2)+1:] for
            # odd lengths, silently dropping the middle character.
            mid = len(text) // 2
            first_half = text[:mid]
            second_half = text[mid:]
        else:
            # BUGFIX: the original joined with "", gluing sentences
            # together with no whitespace (sent_tokenize strips the
            # separators). Re-join with a single space instead.
            half = len(sentences) // 2
            first_half = " ".join(sentences[:half])
            second_half = " ".join(sentences[half:])
        # Recurse on each half; re-tokenize to get its real token count.
        add_splits(len(tokenizer(first_half)["input_ids"]),
                   max_tokens, first_half, list, min_tokens)
        add_splits(len(tokenizer(second_half)["input_ids"]),
                   max_tokens, second_half, list, min_tokens)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in-dir", type=str, required=True)
parser.add_argument("--out-file", type=str, default="")
parser.add_argument("--min-tokens", type=int, default="2048")
parser.add_argument("--max-tokens", type=int, default="4096")
args = parser.parse_args()
files = glob.glob(f'{args.in_dir}**/*.jsonl', recursive=True)
new_content = []
for filename in tqdm.tqdm(files):
with open(filename, "r") as file:
for line in tqdm.tqdm(file.readlines()):
obj = json.loads(line)
token_count = len(tokenizer(obj["text"])["input_ids"])
if token_count >= args.min_tokens and token_count <= args.max_tokens:
new_content.append({
"text": obj["text"]
})
elif token_count > args.max_tokens:
try:
add_splits(token_count, args.max_tokens,
obj["text"], new_content, args.min_tokens)
except:
continue
json_lines = [json.dumps(l) for l in new_content]
json_data = '\n'.join(json_lines)
with open(args.out_file, 'w') as f:
f.write(json_data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment