Skip to content

Instantly share code, notes, and snippets.

@CoffeeVampir3
Created April 20, 2024 02:21
Show Gist options
  • Save CoffeeVampir3/85acbda7ac5db055fa4a5aa04cae2b0d to your computer and use it in GitHub Desktop.
Save CoffeeVampir3/85acbda7ac5db055fa4a5aa04cae2b0d to your computer and use it in GitHub Desktop.
import json
import os
import sys

from transformers import AutoTokenizer
# Path to the Hugging Face model whose tokenizer drives chunking and token
# counting; point this at your own copy of the model.
model = "/home/blackroot/Desktop/llama3-8b/llama-3-8b"
# Upper bound on the number of tokens placed in a single output chunk.
max_tokens = 8_192
def count_tokens_hf(text: str, model_name: str) -> int:
    """Return the number of token ids *model_name*'s tokenizer produces for *text*.

    The tokenizer is loaded with ``AutoTokenizer.from_pretrained`` on every
    call, and ``encode`` is used with its default arguments, so any special
    tokens the tokenizer adds by default are included in the count.

    Args:
        text: The text to tokenize.
        model_name: Model name or local path understood by
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The length of the encoded id sequence.
    """
    token_ids = AutoTokenizer.from_pretrained(model_name).encode(text)
    return len(token_ids)
def chunk_text_hf(text: str, model_name: str, max_tokens: int) -> list:
    """Split *text* into decoded chunks of at most *max_tokens* tokens each.

    The text is tokenized with *model_name*'s tokenizer. Each chunk (except
    the last) is cut right after the last sentence-ending token ('.', '?' or
    '!', possibly followed by whitespace in the same token) that still fits
    within *max_tokens*. The first token of a window is never used as the
    cut point. If the window contains no such token, the chunk is cut hard
    at exactly *max_tokens* tokens. A hard cut may split a word, or, for
    byte-level tokenizers, possibly a multi-byte character.

    Args:
        text: The text to split.
        model_name: Model name or local path understood by
            ``AutoTokenizer.from_pretrained``.
        max_tokens: Maximum number of tokens per chunk; must be >= 1.

    Returns:
        The decoded chunks in their original order; empty for empty text.

    Raises:
        ValueError: If *max_tokens* is less than 1 (previously a negative
            value made the loop never terminate).
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Don't add special tokens (e.g. a BOS marker): decode() would otherwise
    # emit them as literal text at the start of the first chunk.
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    sentence_ends = (".", "?", "!")

    def _ends_sentence(token_id: int) -> bool:
        # Byte-level BPE often fuses punctuation with trailing whitespace
        # (e.g. ".\n"), so compare after stripping that whitespace.
        return tokenizer.decode([token_id]).rstrip().endswith(sentence_ends)

    chunks = []
    start = 0
    total = len(token_ids)
    while start < total:
        # Exclusive end of the largest window that fits in max_tokens; using
        # an exclusive end keeps every chunk at <= max_tokens tokens.
        end = min(start + max_tokens, total)
        if end < total:
            # Walk back over candidate cut points (token indices end-1 down
            # to start+1) and cut after the last sentence-ending token.
            for cut in range(end, start + 1, -1):
                if _ends_sentence(token_ids[cut - 1]):
                    end = cut
                    break
            # If no sentence end was found, keep the hard cut at `end`.
        chunks.append(tokenizer.decode(token_ids[start:end]))
        start = end
    return chunks
def process_file(file_path: str, model_name: str, max_tokens: int) -> None:
    """Chunk a UTF-8 text file and write the chunks as JSON Lines.

    The output path is *file_path* with its extension replaced by ``.jsonl``
    (or ``.jsonl`` appended when there is no extension). Each output line is
    a JSON object ``{"text": <chunk>}``. The token count of every chunk
    (without special tokens) is printed to stdout as it is written.

    Failures are reported on stderr rather than raised: a missing input
    file, an input path that would be overwritten by its own output, or any
    other exception raised while processing.

    Args:
        file_path: Path of the text file to process.
        model_name: Model name or local path understood by
            ``AutoTokenizer.from_pretrained``.
        max_tokens: Maximum number of tokens per chunk.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()

        # splitext only inspects the last path component, so dots in
        # directory names (e.g. "./data/input") are not taken as an extension.
        output_file = os.path.splitext(file_path)[0] + '.jsonl'
        # A ".jsonl" input would otherwise be replaced by its own output.
        if os.path.abspath(output_file) == os.path.abspath(file_path):
            print(f"Refusing to overwrite input file: {file_path}", file=sys.stderr)
            return

        chunks = chunk_text_hf(text, model_name, max_tokens)
        # Load the tokenizer once here instead of once per chunk.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        with open(output_file, 'w', encoding='utf-8') as jsonl_file:
            for chunk in chunks:
                # Count without special tokens so the number reflects only the
                # chunk's own text and is comparable to max_tokens.
                print(len(tokenizer.encode(chunk, add_special_tokens=False)))
                jsonl_file.write(json.dumps({"text": chunk}) + "\n")
        print(f"Successfully processed {file_path} and created {output_file}")
    except FileNotFoundError:
        print(f"File not found: {file_path}", file=sys.stderr)
    except Exception as e:
        # Top-level boundary of this CLI: report the failure instead of
        # printing a traceback.
        print(f"An error occurred: {str(e)}", file=sys.stderr)
if __name__ == "__main__":
    # The script takes exactly one positional argument: the input text file.
    if len(sys.argv) == 2:
        process_file(sys.argv[1], model, max_tokens)
    else:
        print("Usage: python script.py <input_file>")
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment