@CoffeeVampir3
Created April 30, 2024 18:56
Exllama tokenization for training
# Build masked training examples: start from the encoded system prompt, add the
# user instruction, record the prompt length as the loss-mask boundary, append
# the (possibly truncated) assistant response, and dump each example as a JSON line.
for text, inst, summary in zip(text, instructions, previous_summary):
    if summary != "":
        summarized_augmentation = random.choice(augmented_continuation)
        inst = f"{inst}\n\n{summarized_augmentation} {summary}"
    next_prompt = copy.deepcopy(enc_sys_prompt)
    next_message = encode_message(tokenizer, "user", inst)
    next_prompt.extend(next_message)
    mask_length = len(next_prompt)
    next_prompt.extend(encode_message_english_sentence_truncate(tokenizer, "assistant", text, mask_length, 8150))
    total_token_length = len(next_prompt)
    tokens = torch.tensor(next_prompt).unsqueeze(dim=0)
    #print(tokenizer.decode(tokens[:, mask_length:], decode_special_tokens = True))
    #print(mask_length)
    instruction = tokenizer.decode(tokens, decode_special_tokens = True)
    data = {"instruction": instruction, "mask_token_length": mask_length, "total_token_length": total_token_length}
    json_line = json.dumps(data)
    result.append(json_line)
    if a < 3:
        print("*"*200)
        print(inst)
        a = a + 1

with open(output_file, 'w') as file:
    for line in result:
        file.write(line + "\n")
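
# Sketch (assumption, not part of the original gist): a training script that
# consumes the JSONL written above could rebuild the loss mask from
# "mask_token_length" by replacing the prompt span of the label ids with -100,
# so a cross-entropy loss with ignore_index=-100 only scores the assistant turn.
def example_mask_labels(token_ids, mask_token_length):
    labels = list(token_ids)
    for i in range(min(mask_token_length, len(labels))):
        labels[i] = -100  # prompt tokens contribute nothing to the loss
    return labels
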
import sys, os, random, re
import copy, json
import torch

# A requirement for using the exllamav2 API when running from inside the repo
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)

def load_model(model_directory):
    config = ExLlamaV2Config()
    config.model_dir = model_directory
    config.prepare()

    #ratio = 4
    #alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
    #config.rope_alpha = alpha
    config.max_seq_len = 8192*2
    config.max_attention_size = config.max_input_len**2

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
    return config, tokenizer, cache, generator
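
# Usage sketch (assumption, not part of the original gist): load a quantized
# exl2 model directory and build default sampler settings for the streaming
# generator; the sampling values here are illustrative.
def example_setup(model_directory):
    config, tokenizer, cache, generator = load_model(model_directory)
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.7
    settings.top_p = 0.9
    settings.token_repetition_penalty = 1.05
    return config, tokenizer, cache, generator, settings
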
def encode_system(tokenizer, system_prompt):
    # Llama 3 style: <|begin_of_text|>, then a system header and the system
    # prompt, terminated by <|eot_id|>.
    bos_token = tokenizer.single_id("<|begin_of_text|>")
    eot_token = tokenizer.single_id("<|eot_id|>")
    tokens = [bos_token]
    tokens.extend(encode_header(tokenizer, "system"))
    system_ids = tokenizer.encode(system_prompt, add_bos = False).view(-1).tolist()
    tokens.extend(system_ids)
    tokens.append(eot_token)
    return tokens

def encode_header(tokenizer, username):
    tokens = []
    start_header = tokenizer.single_id("<|start_header_id|>")
    end_header = tokenizer.single_id("<|end_header_id|>")
    tokens.append(start_header)
    tokens.extend(tokenizer.encode(username, add_bos = False).view(-1).tolist())
    tokens.append(end_header)
    tokens.extend(tokenizer.encode("\n\n", add_bos = False).view(-1).tolist())
    return tokens

def encode_header_prefilled(tokenizer, username, prefill):
    tokens = []
    start_header = tokenizer.single_id("<|start_header_id|>")
    end_header = tokenizer.single_id("<|end_header_id|>")
    tokens.append(start_header)
    tokens.extend(tokenizer.encode(username, add_bos = False).view(-1).tolist())
    tokens.append(end_header)
    tokens.extend(tokenizer.encode("\n\n", add_bos = False).view(-1).tolist())
    tokens.extend(tokenizer.encode(prefill, add_bos = False).view(-1).tolist())
    return tokens

def encode_message(tokenizer, username, message):
    eot_token = tokenizer.single_id("<|eot_id|>")
    tokens = encode_header(tokenizer, username)
    tokens.extend(
        tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist()
    )
    tokens.append(eot_token)
    return tokens

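# Sketch (assumption, not part of the original gist): the helpers above follow
# the Llama 3 chat template, so a complete prompt is the system block followed
# by user/assistant turns, ending with an open assistant header so the model
# generates the reply.
def example_build_prompt(tokenizer, system_prompt, user_message):
    ids = encode_system(tokenizer, system_prompt)
    ids.extend(encode_message(tokenizer, "user", user_message))
    ids.extend(encode_header(tokenizer, "assistant"))  # the model continues from here
    return torch.tensor(ids).unsqueeze(dim = 0)
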
def encode_message_english_sentence_truncate(tokenizer, username, message, current_length, threshold):
    # Encode a message, dropping whole sentences from the end until the header
    # plus message fits under `threshold` given `current_length` tokens already
    # in the prompt. `olen` remembers the original over-budget length.
    safe = False
    olen = 0
    while not safe:
        eot_token = tokenizer.single_id("<|eot_id|>")
        tokens = encode_header(tokenizer, username)
        tokens.extend(
            tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist()
        )
        token_length = (len(tokens) + current_length)
        if token_length > threshold:
            if olen == 0:
                olen = token_length
            # Remove the last sentence from the message
            sentences = re.findall(r'[^.!?]+[.!?]', message)
            if len(sentences) > 1:
                message = ''.join(sentences[:-1]).strip()
            else:
                message = ""
        else:
            safe = True
    tokens.append(eot_token)
    if olen != 0:
        print(f"Truncated len {olen} to new length {len(tokens) + current_length}")
    return tokens

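# Usage sketch (assumption, not part of the original gist): cap the assistant
# turn so the full training example stays under the ~8k-token budget, matching
# the 8150 threshold used in the dataset-building loop above.
def example_truncated_assistant_turn(tokenizer, prompt_ids, answer_text):
    return encode_message_english_sentence_truncate(
        tokenizer, "assistant", answer_text, len(prompt_ids), 8150)
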
def encode_completion(tokenizer, message):
    tokens = []
    tokens.extend(
        tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist()
    )
    return tokens

def encode_completion_prefilled(tokenizer, message, prefill):
    tokens = []
    tokens.extend(
        tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist()
    )
    tokens.extend(
        tokenizer.encode(prefill, add_bos = False).view(-1).tolist()
    )
    return tokens

@torch.inference_mode()
def generate_response_stream(instruction_ids, generator, tokenizer, settings, stop_sequences=None):
    # Stream generated text chunk by chunk, stopping on the tokenizer's EOS
    # token or on <|eot_id|> (id 128009 for Llama 3). Copy the stop list
    # instead of using a mutable default argument so stop tokens do not
    # accumulate across calls.
    stop_sequences = list(stop_sequences) if stop_sequences is not None else []
    #clear_cache(generator)
    generator.begin_stream_ex(instruction_ids, settings)
    #print("*"*500)
    #print(tokenizer.decode(generator.sequence_ids, decode_special_tokens = True))
    stop_sequences.append(tokenizer.eos_token_id)
    stop_sequences.append(128009)
    generator.set_stop_conditions(stop_sequences)
    while True:
        res = generator.stream_ex()
        chunk = res["chunk"]
        ids = res["chunk_token_ids"]
        counts = len(res["chunk_token_ids"])
        yield chunk, ids, counts
        if res["eos"]:
            return
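
# Usage sketch (assumption, not part of the original gist): build a Llama 3
# prompt with the encode_* helpers and stream a reply; the model path, prompt
# strings, and sampling values below are illustrative placeholders.
if __name__ == "__main__":
    config, tokenizer, cache, generator = load_model("/path/to/exl2-model")
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.7
    prompt = encode_system(tokenizer, "You are a helpful assistant.")
    prompt.extend(encode_message(tokenizer, "user", "Write one sentence about tokenization."))
    prompt.extend(encode_header(tokenizer, "assistant"))
    input_ids = torch.tensor(prompt).unsqueeze(dim = 0)
    for chunk, ids, count in generate_response_stream(input_ids, generator, tokenizer, settings):
        print(chunk, end = "", flush = True)
    print()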