from transformers import AutoTokenizer
import json
import sys

model = "/home/blackroot/Desktop/llama3-8b/llama-3-8b"
max_tokens = 8192


def count_tokens_hf(text: str, model_name: str) -> int:
    """Return the number of tokens the model's tokenizer produces for `text`."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoded_input = tokenizer.encode(text)
    num_tokens = len(encoded_input)
    return num_tokens


def chunk_text_hf(text: str, model_name: str, max_tokens: int) -> list:
    """Split `text` into chunks of at most roughly `max_tokens` tokens,
    preferring to break at sentence-ending punctuation."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoded_input = tokenizer.encode(text)
    chunks = []
    start_idx = 0
    while start_idx < len(encoded_input):
        end_idx = min(start_idx + max_tokens, len(encoded_input))
        if end_idx < len(encoded_input):
            # Walk backwards to the nearest stopping punctuation (., ?, or !);
            # strip() because a decoded token may carry surrounding whitespace.
            while end_idx > start_idx and tokenizer.decode(encoded_input[end_idx]).strip() not in [".", "?", "!"]:
                end_idx -= 1
            # If no stopping punctuation was found, fall back to a hard cut at max_tokens
            if end_idx == start_idx:
                end_idx = min(start_idx + max_tokens, len(encoded_input))
        chunk = tokenizer.decode(encoded_input[start_idx:end_idx + 1])
        chunks.append(chunk)
        start_idx = end_idx + 1
    return chunks


def process_file(file_path: str, model_name: str, max_tokens: int) -> None:
    """Read a UTF-8 text file, chunk it, and write the chunks to a sibling .jsonl file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        chunks = chunk_text_hf(text, model_name, max_tokens)
        output_file = file_path.rsplit('.', 1)[0] + '.jsonl'
        with open(output_file, 'w', encoding='utf-8') as jsonl_file:
            for chunk in chunks:
                # Report each chunk's token count (re-tokenizes the chunk).
                chunk_size = count_tokens_hf(chunk, model_name)
                print(chunk_size)
                json_line = json.dumps({"text": chunk})
                jsonl_file.write(json_line + "\n")
        print(f"Successfully processed {file_path} and created {output_file}")
    except FileNotFoundError:
        print(f"File not found: {file_path}")
    except Exception as e:
        print(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python script.py <input_file>")
        sys.exit(1)

    input_file = sys.argv[1]
    process_file(input_file, model, max_tokens)
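# Usage sketch (the invocation and file names below are illustrative, not part
# of the original gist): given a plain-text file such as "book.txt", running
#
#     python script.py book.txt
#
# writes "book.jsonl" next to it, with one JSON object per chunk, e.g.:
#
#     {"text": "First chunk, ending at a sentence boundary."}
#     {"text": "Second chunk ..."}
#
# while printing each chunk's token count to stdout as it is written.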