Exllama tokenization for training
import copy, json, random

import torch

# This fragment starts mid-script: `texts`, `instructions`, and
# `previous_summary` are assumed parallel lists, `augmented_continuation` a
# list of continuation phrases, `enc_sys_prompt` a pre-encoded system prompt,
# and encode_message / encode_message_english_sentence_truncate come from the
# tokenizer utilities below.
result = []
a = 0  # how many debug examples have been printed so far

for text, inst, summary in zip(texts, instructions, previous_summary):
    if summary != "":
        # Splice a random continuation phrase and the previous summary into the instruction.
        summarized_augmentation = random.choice(augmented_continuation)
        inst = f"{inst}\n\n{summarized_augmentation} {summary}"
    next_prompt = copy.deepcopy(enc_sys_prompt)
    next_prompt.extend(encode_message(tokenizer, "user", inst))
    mask_length = len(next_prompt)  # tokens up to here get masked out of the training loss
    # Truncate the assistant text sentence by sentence so the example stays under 8150 tokens.
    next_prompt.extend(encode_message_english_sentence_truncate(tokenizer, "assistant", text, mask_length, 8150))
    total_token_length = len(next_prompt)
    # Decode a 1-D tensor: ExLlamaV2Tokenizer.decode returns a list, not a string, for batched input.
    tokens = torch.tensor(next_prompt)
    instruction = tokenizer.decode(tokens, decode_special_tokens = True)
    data = {"instruction": instruction, "mask_token_length": mask_length, "total_token_length": total_token_length}
    result.append(json.dumps(data))
    if a < 3:  # print the first three augmented instructions for inspection
        print("*" * 200)
        print(inst)
        a += 1

with open(output_file, 'w') as file:
    for line in result:
        file.write(line + "\n")
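# Consumption sketch (not from the original script): one way a fine-tuning
# loader might read the JSONL written above and turn mask_token_length into a
# label mask, so the loss only covers the assistant text. build_masked_labels
# is hypothetical; it assumes the tokenizer can round-trip special tokens
# (exllamav2's encode accepts encode_special_tokens = True).
import json
import torch

def build_masked_labels(jsonl_path, tokenizer, ignore_index=-100):
    examples = []
    with open(jsonl_path) as f:
        for line in f:
            record = json.loads(line)
            ids = tokenizer.encode(record["instruction"], add_bos = False,
                                   encode_special_tokens = True).view(-1)
            labels = ids.clone()
            labels[:record["mask_token_length"]] = ignore_index  # mask the prompt tokens
            examples.append((ids, labels))
    return examples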
import sys, os, random, re
import torch

# Putting the repo root on sys.path is a requirement for using the exllamav2 API
# when this script lives in a subdirectory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)
def load_model(model_directory):
    config = ExLlamaV2Config()
    config.model_dir = model_directory
    config.prepare()
    # Optional NTK rope scaling, left disabled:
    #ratio = 4
    #alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
    #config.rope_alpha = alpha
    config.max_seq_len = 8192 * 2
    config.max_attention_size = config.max_input_len ** 2
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)  # lazy: allocated during the autosplit load
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
    return config, tokenizer, cache, generator
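# Usage sketch: the directory path below is a placeholder, not one from the
# gist. Everything that follows assumes these four objects are in scope.
#
# config, tokenizer, cache, generator = load_model("/models/llama3-8b-exl2")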
def encode_system(tokenizer, system_prompt):
    # Llama 3 chat format: <|begin_of_text|>, a system header, the prompt body, <|eot_id|>.
    bos_token = tokenizer.single_id("<|begin_of_text|>")
    eot_token = tokenizer.single_id("<|eot_id|>")
    tokens = [bos_token]
    tokens.extend(encode_header(tokenizer, "system"))
    tokens.extend(tokenizer.encode(system_prompt, add_bos = False).view(-1).tolist())
    tokens.append(eot_token)
    return tokens

def encode_header(tokenizer, username):
    # <|start_header_id|>{username}<|end_header_id|> followed by a blank line.
    tokens = [tokenizer.single_id("<|start_header_id|>")]
    tokens.extend(tokenizer.encode(username, add_bos = False).view(-1).tolist())
    tokens.append(tokenizer.single_id("<|end_header_id|>"))
    tokens.extend(tokenizer.encode("\n\n", add_bos = False).view(-1).tolist())
    return tokens

def encode_header_prefilled(tokenizer, username, prefill):
    # Same header, with the start of the role's message pre-filled.
    tokens = encode_header(tokenizer, username)
    tokens.extend(tokenizer.encode(prefill, add_bos = False).view(-1).tolist())
    return tokens

def encode_message(tokenizer, username, message):
    # One complete chat turn: header, stripped message body, <|eot_id|>.
    tokens = encode_header(tokenizer, username)
    tokens.extend(tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist())
    tokens.append(tokenizer.single_id("<|eot_id|>"))
    return tokens
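# Assembly sketch (assumed usage, mirroring the training script above): a full
# Llama 3 style prompt is the system block, one or more turns, then an empty
# assistant header for the model to continue from.
def build_chat_prompt(tokenizer, system_prompt, user_message):
    tokens = encode_system(tokenizer, system_prompt)
    tokens.extend(encode_message(tokenizer, "user", user_message))
    tokens.extend(encode_header(tokenizer, "assistant"))  # generation starts here
    return tokens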
def encode_message_english_sentence_truncate(tokenizer, username, message, current_length, threshold):
    # Like encode_message, but drops trailing sentences until the encoded turn
    # plus current_length fits under threshold tokens.
    eot_token = tokenizer.single_id("<|eot_id|>")
    safe = False
    olen = 0
    while not safe:
        tokens = encode_header(tokenizer, username)
        tokens.extend(tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist())
        token_length = len(tokens) + current_length
        if token_length > threshold:
            if olen == 0:
                olen = token_length  # remember the pre-truncation length for logging
            # Remove the last sentence from the message
            sentences = re.findall(r'[^.!?]+[.!?]', message)
            if len(sentences) > 1:
                message = ''.join(sentences[:-1]).strip()
            else:
                message = ""  # a single over-long sentence is dropped entirely
        else:
            safe = True
    tokens.append(eot_token)
    if olen != 0:
        print(f"Truncated len {olen} to new length {len(tokens) + current_length}")
    return tokens
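# Truncation sketch (hypothetical numbers): with current_length = 8000 and the
# 8150 threshold used by the training script, an over-long assistant message
# loses whole sentences from the end until the turn fits. The result can still
# exceed the threshold by the single <|eot_id|> appended after the length check.
#
# body = encode_message_english_sentence_truncate(
#     tokenizer, "assistant", long_text, 8000, 8150)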
def encode_completion(tokenizer, message):
    # Bare completion text with no chat framing.
    return tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist()

def encode_completion_prefilled(tokenizer, message, prefill):
    # Completion text followed immediately by a prefill string.
    tokens = encode_completion(tokenizer, message)
    tokens.extend(tokenizer.encode(prefill, add_bos = False).view(-1).tolist())
    return tokens
@torch.inference_mode()
def generate_response_stream(instruction_ids, generator, tokenizer, settings, stop_sequences=None):
    # Copy the stop list so the caller's list (and the default) is never mutated.
    stop_sequences = list(stop_sequences) if stop_sequences else []
    generator.begin_stream_ex(instruction_ids, settings)
    stop_sequences.append(tokenizer.eos_token_id)
    stop_sequences.append(128009)  # Llama 3 <|eot_id|>
    generator.set_stop_conditions(stop_sequences)
    while True:
        res = generator.stream_ex()
        chunk = res["chunk"]
        ids = res["chunk_token_ids"]
        yield chunk, ids, len(ids)
        if res["eos"]:
            return
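# Streaming usage sketch (assumed, not part of the gist): collect a full reply
# from the generator defined above. The sampler values are placeholders.
def stream_reply(generator, tokenizer, prompt_tokens):
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.8  # placeholder sampling values
    settings.top_p = 0.9
    ids = torch.tensor(prompt_tokens).unsqueeze(dim = 0)
    reply = ""
    for chunk, _, _ in generate_response_stream(ids, generator, tokenizer, settings):
        reply += chunk
    return reply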