Rolling system prompt
from exllamav2 import (
    ExLlamaV2Tokenizer,
)


def encode_message(tokenizer: ExLlamaV2Tokenizer, role: str, message: str) -> list:
    """Encode a complete Gemma-2 turn: <start_of_turn>{role}\n{message}<end_of_turn>\n."""
    tokens = []
    start_token = tokenizer.single_id("<start_of_turn>")
    end_token = tokenizer.single_id("<end_of_turn>")
    tokens.append(start_token)
    tokens.extend(tokenizer.encode(f"{role}\n", add_bos=False).view(-1).tolist())
    tokens.extend(tokenizer.encode(message.strip(), add_bos=False).view(-1).tolist())
    tokens.append(end_token)
    tokens.extend(tokenizer.encode("\n", add_bos=False).view(-1).tolist())
    return tokens


def terminate_message(tokenizer: ExLlamaV2Tokenizer, message: str) -> list:
    """Encode a message body and close it with <end_of_turn>, without an opening header."""
    tokens = []
    end_token = tokenizer.single_id("<end_of_turn>")
    tokens.extend(tokenizer.encode(message.strip(), add_bos=False).view(-1).tolist())
    tokens.append(end_token)
    return tokens


def encode_header(tokenizer: ExLlamaV2Tokenizer, role: str) -> list:
    """Encode only the opening header <start_of_turn>{role}\n, leaving the turn open for generation."""
    tokens = []
    start_token = tokenizer.single_id("<start_of_turn>")
    tokens.append(start_token)
    tokens.extend(tokenizer.encode(f"{role}\n", add_bos=False).view(-1).tolist())
    return tokens
from flask import Flask, render_template
import torch
from flask_socketio import SocketIO, emit
from generation.exllama_generator_wrapper import get_generator_func
from formats.gemma2_format import encode_header, encode_message, terminate_message
from collections import deque
import time
import os, sys, json

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2Lora,
)

app = Flask(__name__)
socketio = SocketIO(app)

sys_role = "model"
user_role = "user"
system_prompt = """Example."""

# Positional args: model directory, optionally a LoRA directory.
model, tokenizer, generator, generate = get_generator_func(sys.argv[1])

jsonl_file = 'conversation.jsonl'


def write_system_prompt(system_prompt):
    with open(jsonl_file, 'w') as f:
        json.dump({"system": system_prompt}, f)
        f.write('\n')


lora = None
if len(sys.argv) > 2:
    lora = ExLlamaV2Lora.from_directory(model, sys.argv[2])
    print(lora)

# Conversation state: turn index -> list of encoded token chunks for that turn.
message_pairs = {}
turn_counter = 0
current_ctx_size = 0
ctx_buffer_trim_token_thresh = 30000


@app.route('/')
def index():
    return render_template('index.html', system_prompt=system_prompt)


@socketio.on('purge_messages')
def handle_purge_messages():
    global message_pairs, turn_counter
    message_pairs = {}
    turn_counter = 0


@socketio.on('send_message')
def handle_send_message(message, new_sys_prompt):
    global generate, message_pairs, turn_counter, generator, sys_role, user_role, current_ctx_size, lora

    emit('user_message', message, broadcast=True)
    emit('start_generation', broadcast=True)  # tell clients generation has started

    # The system prompt is re-encoded on every request, so edits to it "roll" into the next turn.
    enc_sys_prompt = encode_message(tokenizer, "system", new_sys_prompt)
    debug_out = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)
    debug_v = tokenizer.decode(debug_out, decode_special_tokens=True)
    print(debug_v)

    turn_counter += 1
    next_message = encode_message(tokenizer, user_role, message)
    # Store the user turn plus an open model header; the model's reply is appended after generation.
    message_pairs[turn_counter] = [next_message, encode_header(tokenizer, sys_role)]

    with open(jsonl_file, 'a') as f:
        json.dump({user_role: message}, f)
        f.write('\n')

    # If the previous prompt exceeded the threshold, drop up to the 20 oldest turns to make room.
    if current_ctx_size > ctx_buffer_trim_token_thresh:
        for _ in range(0, 20):
            if not message_pairs:
                break  # guard against emptying the history entirely
            oldest_turn = min(message_pairs.keys())
            del message_pairs[oldest_turn]

    # Rebuild the full prompt: system prompt followed by every surviving turn in order.
    sorted_turns = sorted(message_pairs.keys())
    for turn in sorted_turns:
        messages = message_pairs[turn]
        for msg in messages:
            enc_sys_prompt.extend(msg)

    testing = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)
    tokens_out = 0
    response = []
    print("*" * 30)

    # Stream generated fragments to the client as they arrive.
    message_buffer = ""
    for fragment, ids, count in generate(instruction_ids=testing, generator=generator, loras=lora):
        emit('stream_response', {'fragment': fragment}, broadcast=True)
        tokens_out += count
        message_buffer += fragment
        #print(str(ids) + " ", flush=True, end="")
        #if tokens_out + len(enc_sys_prompt) > 7500:
        #    break

    # Close the model turn and append it to the current turn's token chunks.
    enc_turn = terminate_message(tokenizer, message_buffer)
    with open(jsonl_file, 'a') as f:
        json.dump({sys_role: message_buffer}, f)
        f.write('\n')

    message_pairs[turn_counter].append(enc_turn)
    current_ctx_size = len(enc_turn) + len(enc_sys_prompt)

    dbg = []
    for item in message_pairs[turn_counter]:
        dbg.extend(item)
    debug_out = torch.tensor(dbg).unsqueeze(dim=0)
    debug_v = tokenizer.decode(debug_out, decode_special_tokens=True)
    print(debug_v)

    emit('update_token_count', len(enc_sys_prompt) + tokens_out, broadcast=True)
    emit('end_generation', broadcast=True)  # tell clients generation has ended


if __name__ == '__main__':
    socketio.run(app)
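A quick way to exercise the endpoint without the bundled index.html template is a small Socket.IO client. This is a sketch under stated assumptions: the server was started with something like python <this_script>.py /path/to/model (optionally a LoRA directory as a second argument), Flask-SocketIO is listening on its default http://127.0.0.1:5000, and the python-socketio client package is installed; the message text and system prompt are placeholders.

import socketio

sio = socketio.Client()

@sio.on('stream_response')
def on_fragment(data):
    # The server streams the reply one fragment at a time.
    print(data['fragment'], end='', flush=True)

@sio.on('end_generation')
def on_done():
    print()
    sio.disconnect()

sio.connect('http://127.0.0.1:5000')
# handle_send_message takes (message, new_sys_prompt); a tuple sends both positional args.
sio.emit('send_message', ('Hello there.', 'Example.'))
sio.wait()

Because every request carries new_sys_prompt, the client can change the system prompt on any turn, and older turns are trimmed once the context passes ctx_buffer_trim_token_thresh, which is the rolling behaviour the gist title refers to.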